From 88b3d3371b668e2e3f89218629d764a8e5868f0c Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Fri, 11 Dec 2020 08:08:40 -0800 Subject: [PATCH 01/33] add additional arm64 checker in cmake files (#48952) Summary: tentatively fixes https://github.com/pytorch/pytorch/issues/48873 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48952 Reviewed By: H-Huang Differential Revision: D25463266 Pulled By: walterddr fbshipit-source-id: 40afefffe8ab98ae7261c770316cb9c25225285f --- aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt | 2 +- cmake/External/nnpack.cmake | 2 +- third_party/NNPACK | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt index 9f8cb6d9ed09..99bf8ba07074 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt @@ -271,7 +271,7 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") set_property(SOURCE ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") endif() endif() -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^arm64.*") set_property(SOURCE ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") if(IOS) set_property(SOURCE ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") diff --git a/cmake/External/nnpack.cmake b/cmake/External/nnpack.cmake index 84244dc864c3..b1dcd728e690 100644 --- a/cmake/External/nnpack.cmake +++ b/cmake/External/nnpack.cmake @@ -27,7 +27,7 @@ endif() # (2) Anything but x86, x86-64, ARM, ARM64 - unsupported ############################################################################## if(CMAKE_SYSTEM_PROCESSOR) - if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|aarch64)$") + if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|arm64|aarch64)$") message(WARNING "NNPACK is not supported on ${CMAKE_SYSTEM_PROCESSOR} processors. " "The only supported architectures are x86, x86-64, ARM, and ARM64. " "Turn this warning off by USE_NNPACK=OFF.") diff --git a/third_party/NNPACK b/third_party/NNPACK index 24b55303f5cf..57616b9a0ef7 160000 --- a/third_party/NNPACK +++ b/third_party/NNPACK @@ -1 +1 @@ -Subproject commit 24b55303f5cf65d75844714513a0d1b1409809bd +Subproject commit 57616b9a0ef7b0f8e56bfe7e9738744b52fe1828 From 2bb2f641c47e2705b8d2bc9514c67764182b1462 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 11 Dec 2020 08:15:13 -0800 Subject: [PATCH 02/33] Bring fast_nvcc.py to PyTorch OSS (#48934) Summary: This PR adds `tools/fast_nvcc/fast_nvcc.py`, a mostly-transparent wrapper over `nvcc` that parallelizes compilation of CUDA files when building for multiple architectures at once. Pull Request resolved: https://github.com/pytorch/pytorch/pull/48934 Test Plan: Currently this script isn't actually used in PyTorch OSS. Coming soon! Reviewed By: walterddr Differential Revision: D25286030 Pulled By: samestep fbshipit-source-id: 971a404cf57f5694dea899a27338520d25191706 --- tools/README.md | 5 + tools/fast_nvcc/fast_nvcc.py | 463 +++++++++++++++++++++++++++++++++++ 2 files changed, 468 insertions(+) create mode 100755 tools/fast_nvcc/fast_nvcc.py diff --git a/tools/README.md b/tools/README.md index 527351d1c84a..b940d378320b 100644 --- a/tools/README.md +++ b/tools/README.md @@ -29,6 +29,11 @@ Build system pieces: * [build_libtorch.py](build_libtorch.py) - Script for building libtorch, a standalone C++ library without Python support. This build script is tested in CI. +* [fast_nvcc](fast_nvcc) - Mostly-transparent wrapper over nvcc that + parallelizes compilation when used to build CUDA files for multiple + architectures at once. + * [fast_nvcc.py](fast_nvcc/fast_nvcc.py) - Python script, entrypoint to the + fast nvcc wrapper. Developer tools which you might find useful: diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py new file mode 100755 index 000000000000..2a8d1d731453 --- /dev/null +++ b/tools/fast_nvcc/fast_nvcc.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 + +import argparse +import asyncio +import collections +import csv +import hashlib +import itertools +import os +import pathlib +import re +import shlex +import shutil +import subprocess +import sys +import time + + +help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]... + +Run the commands given by nvcc --dryrun, in parallel. + +All flags for this script itself (see the "optional arguments" section +of --help) must be passed before the first "--". Everything after that +first "--" is passed directly to nvcc, with the --dryrun argument added. + +This script only works with the "normal" execution path of nvcc, so for +instance passing --help (after "--") doesn't work since the --help +execution path doesn't compile anything, so adding --dryrun there gives +nothing in stderr. +''' +parser = argparse.ArgumentParser(help_msg) +parser.add_argument( + '--faithful', + action='store_true', + help="don't modify the commands given by nvcc (slower)", +) +parser.add_argument( + '--graph', + metavar='FILE.dot', + help='write Graphviz DOT file with execution graph', +) +parser.add_argument( + '--nvcc', + metavar='PATH', + default='nvcc', + help='path to nvcc (default is just "nvcc")', +) +parser.add_argument( + '--save', + metavar='DIR', + help='copy intermediate files from each command into DIR', +) +parser.add_argument( + '--sequential', + action='store_true', + help='sequence commands instead of using the graph (slower)', +) +parser.add_argument( + '--table', + metavar='FILE.csv', + help='write CSV with times and intermediate file sizes', +) +parser.add_argument( + '--verbose', + metavar='FILE.txt', + help='like nvcc --verbose, but expanded and into a file', +) +default_config = parser.parse_args([]) + + +# docs about temporary directories used by NVCC +url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' +url_vars = f'{url_base}#keeping-intermediate-phase-files' + + +# regex for temporary file names +re_tmp = r'(? '{filename}'") + uniqueified.append(line) + return uniqueified + + +def make_rm_force(commands): + """ + Add --force to all rm commands. + """ + return [f'{c} --force' if c.startswith('rm ') else c for c in commands] + + +def print_verbose_output(*, env, commands, filename): + """ + Human-readably write nvcc --dryrun data to stderr. + """ + padding = len(str(len(commands) - 1)) + with open(filename, 'w') as f: + for name, val in env.items(): + print(f'#{" "*padding}$ {name}={val}', file=f) + for i, command in enumerate(commands): + prefix = f'{str(i).rjust(padding)}$ ' + print(f'#{prefix}{command[0]}', file=f) + for part in command[1:]: + print(f'#{" "*len(prefix)}{part}', file=f) + + +def straight_line_dependencies(commands): + """ + Return a straight-line dependency graph. + """ + return [({i - 1} if i > 0 else set()) for i in range(len(commands))] + + +def files_mentioned(command): + """ + Return fully-qualified names of all tmp files referenced by command. + """ + return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)] + + +def nvcc_data_dependencies(commands): + """ + Return a list of the set of dependencies for each command. + """ + # fatbin needs to be treated specially because while the cicc steps + # do refer to .fatbin.c files, they do so through the + # --include_file_name option, since they're generating files that + # refer to .fatbin.c file(s) that will later be created by the + # fatbinary step; so for most files, we make a data dependency from + # the later step to the earlier step, but for .fatbin.c files, the + # data dependency is sort of flipped, because the steps that use the + # files generated by cicc need to wait for the fatbinary step to + # finish first + tmp_files = {} + fatbins = collections.defaultdict(set) + graph = [] + for i, line in enumerate(commands): + deps = set() + for tmp in files_mentioned(line): + if tmp in tmp_files: + dep = tmp_files[tmp] + deps.add(dep) + if dep in fatbins: + for filename in fatbins[dep]: + if filename in tmp_files: + deps.add(tmp_files[filename]) + if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'): + fatbins[i].add(tmp) + else: + tmp_files[tmp] = i + if line.startswith('rm ') and not deps: + deps.add(i - 1) + graph.append(deps) + return graph + + +def is_weakly_connected(graph): + """ + Return true iff graph is weakly connected. + """ + neighbors = [set() for _ in graph] + for node, predecessors in enumerate(graph): + for pred in predecessors: + neighbors[pred].add(node) + neighbors[node].add(pred) + # assume nonempty graph + stack = [0] + found = {0} + while stack: + node = stack.pop() + for neighbor in neighbors[node]: + if neighbor not in found: + found.add(neighbor) + stack.append(neighbor) + return len(found) == len(graph) + + +def warn_if_not_weakly_connected(graph): + """ + Warn the user if the execution graph is not weakly connected. + """ + if not is_weakly_connected(graph): + fast_nvcc_warn('execution graph is not (weakly) connected') + + +def print_dot_graph(*, commands, graph, filename): + """ + Print a DOT file displaying short versions of the commands in graph. + """ + def name(k): + return f'"{k} {os.path.basename(commands[k][0])}"' + with open(filename, 'w') as f: + print('digraph {', file=f) + # print all nodes, in case it's disconnected + for i in range(len(graph)): + print(f' {name(i)};', file=f) + for i, deps in enumerate(graph): + for j in deps: + print(f' {name(j)} -> {name(i)};', file=f) + print('}', file=f) + + +async def run_command(command, *, env, deps, gather_data, i, save): + """ + Run the command with the given env after waiting for deps. + """ + for task in deps: + await task + if gather_data: + t1 = time.monotonic() + proc = await asyncio.create_subprocess_shell( + command, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + code = proc.returncode + results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr} + if gather_data: + t2 = time.monotonic() + results['time'] = t2 - t1 + sizes = {} + for tmp_file in files_mentioned(command): + if os.path.exists(tmp_file): + sizes[tmp_file] = os.path.getsize(tmp_file) + else: + sizes[tmp_file] = 0 + results['files'] = sizes + if save: + dest = pathlib.Path(save) / str(i) + dest.mkdir() + for src in map(pathlib.Path, files_mentioned(command)): + if src.exists(): + shutil.copy2(src, dest / (src.name)) + return results + + +async def run_graph(*, env, commands, graph, gather_data, save): + """ + Return outputs/errors (and optionally time/file info) from commands. + """ + tasks = [] + for i, (command, indices) in enumerate(zip(commands, graph)): + deps = {tasks[j] for j in indices} + tasks.append(asyncio.create_task(run_command( + command, + env=env, + deps=deps, + gather_data=gather_data, + i=i, + save=save, + ))) + return [await task for task in tasks] + + +def print_command_outputs(command_results): + """ + Print captured stdout and stderr from commands. + """ + for result in command_results: + sys.stdout.write(result['stdout'].decode('ascii')) + sys.stderr.write(result['stderr'].decode('ascii')) + + +def write_log_csv(command_parts, command_results, *, filename): + """ + Write a CSV file of the times and /tmp file sizes from each command. + """ + tmp_files = [] + for result in command_results: + tmp_files.extend(result['files'].keys()) + with open(filename, 'w', newline='') as csvfile: + fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files)) + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for i, result in enumerate(command_results): + command = f'{i} {os.path.basename(command_parts[i][0])}' + row = {'command': command, 'seconds': result['time']} + writer.writerow({**row, **result['files']}) + + +def exit_code(results): + """ + Aggregate individual exit codes into a single code. + """ + for result in results: + code = result['exit_code'] + if code != 0: + return code + return 0 + + +def fast_nvcc(args, *, config=default_config): + """ + Emulate the result of calling the given nvcc binary with args. + + Should run faster than plain nvcc. + """ + warn_if_windows() + warn_if_tmpdir_flag(args) + dryrun_data = nvcc_dryrun_data(config.nvcc, args) + env = dryrun_data['env'] + warn_if_tmpdir_set(env) + commands = dryrun_data['commands'] + if not config.faithful: + commands = make_rm_force(unique_module_id_files(commands)) + command_parts = list(map(shlex.split, commands)) + if config.verbose: + print_verbose_output( + env=env, + commands=command_parts, + filename=config.verbose, + ) + graph = nvcc_data_dependencies(commands) + warn_if_not_weakly_connected(graph) + if config.graph: + print_dot_graph( + commands=command_parts, + graph=graph, + filename=config.graph, + ) + if config.sequential: + graph = straight_line_dependencies(commands) + results = asyncio.run(run_graph( + env=env, + commands=commands, + graph=graph, + gather_data=bool(config.table), + save=config.save, + )) + print_command_outputs(results) + if config.table: + write_log_csv(command_parts, results, filename=config.table) + return exit_code([dryrun_data] + results) + + +def our_arg(arg): + return arg != '--' + + +if __name__ == '__main__': + argv = sys.argv[1:] + us = list(itertools.takewhile(our_arg, argv)) + them = list(itertools.dropwhile(our_arg, argv)) + sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us))) From dcd1e3d78d3163aec75a2eb1aedb4241e01a9c78 Mon Sep 17 00:00:00 2001 From: generatedunixname89002005325676 Date: Fri, 11 Dec 2020 08:39:25 -0800 Subject: [PATCH 03/33] [AutoAccept][Codemod][FBSourceClangFormatLinter] Daily `arc lint --take CLANGFORMAT` Reviewed By: zertosh Differential Revision: D25490983 fbshipit-source-id: b24a11214a485a4a24ccf7da1e72715b450d3a81 --- test/cpp/tensorexpr/test_kernel.cpp | 2 +- torch/csrc/jit/tensorexpr/kernel.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 895b025ac4e0..cf658ad488f6 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -522,7 +522,7 @@ TEST(Kernel, DISABLED_SumAllAxes) { std::string li_to_str(at::ArrayRef li) { std::stringstream out; bool first = true; - for (auto elem: li) { + for (auto elem : li) { if (!first) { out << ", "; } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index c4228ae955b6..88cf5761cfa1 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -8,7 +9,6 @@ #include #include #include -#include using namespace torch::jit; using namespace torch::jit::tensorexpr; From f204f77e6d0933b274c5ebc5c8f7ce1e1ee1c2fd Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 11 Dec 2020 09:23:09 -0800 Subject: [PATCH 04/33] Drop FutureNCCL in favor of vanilla CUDAFuture (#49014) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49014 We extracted a generic and reusable CUDAFuture class from FutureNCCL, but we had left FutureNCCL around, as a subclass of CUDAFuture, in order to deal with some peculiarity of ProcessGroupNCCL, namely that the future would be completed right away when constructed and that its CUDA events would be _shared_ with the ones of the WorkNCCL. This required some "hacks" in CUDAFuture itself (protected members, fields wrapped in shared_ptrs, ...). My understanding is that creating CUDA events is a rather cheap operation. That would mean that we could afford to record _twice_ the events after each NCCL call, once for the WorkNCCL and once for the future. By doing so, we can use the CUDAFuture class directly and revert all its hacks. ghstack-source-id: 118391217 Test Plan: Unit tests Reviewed By: mrshenli Differential Revision: D25355272 fbshipit-source-id: 3a2a0891724928221ff0f08600675d2f5990e674 --- aten/src/ATen/cuda/CUDAFuture.h | 22 +++------- torch/csrc/distributed/c10d/init.cpp | 15 +++---- torch/lib/c10d/ProcessGroupNCCL.cpp | 22 ++++++++-- torch/lib/c10d/ProcessGroupNCCL.hpp | 64 +++------------------------- 4 files changed, 37 insertions(+), 86 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAFuture.h b/aten/src/ATen/cuda/CUDAFuture.h index 4334101478f1..ae43fb2a2dd6 100644 --- a/aten/src/ATen/cuda/CUDAFuture.h +++ b/aten/src/ATen/cuda/CUDAFuture.h @@ -21,7 +21,7 @@ namespace at { namespace cuda { -struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { +struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future { public: using at::ivalue::Future::Future; @@ -56,12 +56,11 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { } } - cudaEvents_ = std::make_shared>(); for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) { if (isCudaDeviceUsed[idx]) { at::cuda::CUDAEvent cudaEvent; cudaEvent.record(at::cuda::getCurrentCUDAStream(idx)); - (*cudaEvents_).push_back(std::move(cudaEvent)); + cudaEvents_.push_back(std::move(cudaEvent)); } } } @@ -78,7 +77,7 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { // misbehaving this also ends up using memory on those devices, which the // user might not want. std::vector streams; - for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { c10::DeviceIndex idx = cudaEvent.device_index(); // FIXME Should we find a way to allow to change the priority of // streams? @@ -107,7 +106,7 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { } void postWaitHook(const at::IValue& value) override { - for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { cudaEvent.block( at::cuda::getCurrentCUDAStream(cudaEvent.device_index())); } @@ -120,12 +119,7 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { } } - // FIXME This field is protected (rather than private) and wrapped in a - // shared_ptr in order to support the FutureNCCL subclass, which wants to set - // the events on its own in order to use the same ones as its WorkNCCL class. - // Once WorkNCCL is gone (as part of the Future and Work merge) this should be - // fixed. - protected: + private: // The device that was current when markCompleted was called, which we'll // restore when invoking callbacks. c10::DeviceIndex currentDevice_; @@ -134,19 +128,15 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { // are recorded on the appropriate streams when the future is marked completed // and can then be queried/waited/blocked on. There is one event for each // distinct device on which the value's tensors reside. - std::shared_ptr> cudaEvents_; + std::vector cudaEvents_; // A cached version of the data ptrs extracted from the value when the future // is first marked completed. std::vector> dataPtrs_; - private: DataPtrExtractor dataPtrExtractor_; std::mutex dataPtrExtractorMutex_; - // FIXME This too is protected so that it can be used by FutureNCCL. Please - // undo that once FutureNCCL is dropped in favor of a "vanilla" CUDAFuture. - protected: std::vector> extractDataPtrs( const at::IValue& value) { std::unique_lock lock(dataPtrExtractorMutex_); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 54fc33e54424..0a7daa3a5b94 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1127,24 +1127,23 @@ that adds a prefix to each key inserted to the store. >>> ddp_model._egister_comm_hook(state = None, hook = allreduce) .. warning :: - ``get_future`` API supports only NCCL backend and single-process single-device mode. + ``get_future`` API supports only NCCL backend. The ``torch._C.Future`` object returned by this API can be used in - ``DistributedDataParallel.register_comm_hook``, but it is subject to some subtle - differences compared to ``torch.futures.Future`` due to compromises made for performance - reasons. + ``DistributedDataParallel.register_comm_hook``, and adds some CUDA-specific + features on top of ``torch.futures.Future``. In the example above, ``allreduce`` work will be done on GPU using NCCL backend, ``fut.wait()`` will return after synchronizing the appropriate NCCL streams - with PyTorch's default device streams to ensure we can have asynchronous CUDA + with PyTorch's current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that - ``FutureNCCL`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. + ``CUDAFuture`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. In addition, if a callback function was added by ``fut.then()``, it will wait until ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. - ``fut.then()`` will return another ``FutureNCCL`` that holds the return value of the + ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the callback and a ``CUDAEvent`` that recorded the callback stream. - Note that ``fut.done()`` returns if the enire operation is completed on the GPU. + Note that ``fut.done()`` returns only whether the operation has been enqueued on the GPU. )"); module.def( diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 5152ce01e25e..19085f155020 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -1008,7 +1008,7 @@ std::vector ProcessGroupNCCL::WorkNCCL::result() { c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: getFuture() { - return c10::make_intrusive(at::IValue(*outputs_), cudaEvents_); + return future_; } void ProcessGroupNCCL::workEnqueue( @@ -1046,7 +1046,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( bool can_profile = outputs.size() == 1; auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr); - // Store references to outputs to be used by WorkNCCL::getFuture. + // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); at::cuda::OptionalCUDAGuard gpuGuard; @@ -1088,6 +1088,13 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->ncclComms_[i] = ncclComms[i]; } + { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + // Set appropriate work parameters. work->blockingWait_ = blockingWait_; work->opTimeout_ = opTimeout_; @@ -1097,7 +1104,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // recordFunctionEndCallback_ is normally called in fininsh() function by // base class, but since finish is not called by WorkNCCL, we schedule this // function to be run when work is done. Note that addCallback() onto the - // Work's futureNCCL is not useful here, as it would just run the callback + // Work's CUDAFuture is not useful here, as it would just run the callback // inline. // Note when can_profile is false, profilingTitle is not provided and so, // recordFunctionEndCallback_ is not set. @@ -1132,7 +1139,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( auto work = initWork(devices, rank_, opType); if (opType == OpType::RECV) { - // Store references to outputs to be used by WorkNCCL::getFuture. + // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(tensors); } @@ -1178,6 +1185,13 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( work->store_ = store_; } + if (opType == OpType::RECV) { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + return work; } diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index adbfec445549..4d9dc3bd1ae8 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -110,7 +110,7 @@ class ProcessGroupNCCL : public ProcessGroup { bool finishedGPUExecution(); // Get a Future object that will be marked as completed internally. - // It actually returns a FutureNCCL object which is a sub class Future. + // It actually returns a CUDAFuture object which is a sub class of Future. c10::intrusive_ptr getFuture() override; // Helper function that sets an exception_ptr on the WorkNCCL object. @@ -170,9 +170,13 @@ class ProcessGroupNCCL : public ProcessGroup { // to the store. c10::intrusive_ptr store_; - // Store a reference to NCCL collective's outputs to be used by getFuture. + // Store a reference to NCCL collective's outputs, used by result and to + // give a more descriptive message when representing the Work as a string. std::shared_ptr> outputs_; + // The future returned by getFuture. + c10::intrusive_ptr future_; + friend class ProcessGroupNCCL; }; @@ -190,62 +194,6 @@ class ProcessGroupNCCL : public ProcessGroup { bool isHighPriorityStream; }; - // FutureNCCL is a subclass of ivalue's Future. The goal is to use - // this class in getFuture API of WorkNCCL. This Future is mostly a - // wrapper to synchronize streams appropriately and it mostly enables - // the async programming model of CUDA while trying to adhere to the - // Future interface. FutureNCCL does not support NCCL_BLOCKING_WAIT flag - // or NCCL's barrier(). - // - // If created by WorkNCCL's getFuture API, FutureNCCL has a reference to - // WorkNCCL's cudaEvents, NCCL collective's outputs, and the device indices of - // outputs' devices. Its value is NCCL collective's outputs. - // - // If created by FutureNCCL's then callback, its value becomes the value of - // callback() and its cudaEvents will record the NCCL stream that runs that - // callback. Before invoking the callback, FutureNCCL will synchronize its - // own cudaEvents with the stream that runs the callback. This design - // enables synchronizing the appropriate streams and avoids stalling PyTorch's - // default stream while running the callback. In case of multiple then - // callbacks, each will be executed on its own fresh stream. - struct FutureNCCL : at::cuda::CUDAFuture { - public: - FutureNCCL( - at::IValue value, - std::shared_ptr> cudaEvents) - : at::cuda::CUDAFuture(c10::ListType::create(c10::TensorType::get())){ - // Check that the device indices are distinct - std::unordered_set uniqueDeviceIndices; - for (const at::cuda::CUDAEvent& event : *cudaEvents) { - TORCH_INTERNAL_ASSERT(event.isCreated()); - uniqueDeviceIndices.insert(event.device_index()); - } - TORCH_INTERNAL_ASSERT( - cudaEvents->size() == uniqueDeviceIndices.size(), - "Got ", cudaEvents->size(), " events, but only ", - uniqueDeviceIndices.size(), " distinct devices"); - auto dataPtrs = extractDataPtrs(value); - for (const at::DataPtr& data_ptr : dataPtrs) { - TORCH_INTERNAL_ASSERT( - std::find_if( - cudaEvents->begin(), - cudaEvents->end(), - [&](const at::cuda::CUDAEvent& ev) { - return ev.device_index() == data_ptr.device().index(); - }) != cudaEvents->end()); - } - currentDevice_ = c10::cuda::current_device(); - cudaEvents_ = std::move(cudaEvents); - dataPtrs_ = std::move(dataPtrs); - markCompleted(std::move(value)); - } - - protected: - void postMarkCompletedHook(const at::IValue& value) override { - // Do nothing because the constructor already stored the events. - } - }; - // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can From f10b53d9eaa282d240fa79ee8f03ea42457803d5 Mon Sep 17 00:00:00 2001 From: Dhruv Matani Date: Fri, 11 Dec 2020 09:38:53 -0800 Subject: [PATCH 05/33] [PyTorch Mobile] Record dtypes for tensors used in kernel function implementations (#48826) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48826 This change updates various macros to pass in the kernel tag string (`const char*`) into the macro that sets up the `case` statement for the dtype switch. This macro already receives the dtype (enum) which we also need. There are 2 phases we need to build out for the `dtype` tracing to work: 1. Recording Phase 2. Conditional Compilation Phase For this most part, this change is trying to focus on [1] (The Recording Phase) and sets up a new `RecordScope` enum value to track kernel dtypes. This code is compiled in only if a specific macro is defined (since this is an **extremely** hot code path, and even the slightest regression here can cause tremendous slow down overall). I have only added a skeleton of the phase [2] (Conditional Compilation Phase) and there is a no-op `constexpr` method that selects every dtype in the kernel implementation. In subsequent diffs, this will be updated to point to a code-generated function based on the result of tracing the models that were requested. ghstack-source-id: 118336675 Test Plan: See the next few diff in the stack for the application of this change to both record triggered dtypes (in kernel functions) as well as select dtype specific portions of kernel functions. Reviewed By: ezyang Differential Revision: D24220926 fbshipit-source-id: d7dbf21c7dcc6ce981d0fd4dcb62ca829fe3f69d --- aten/src/ATen/Dispatch.h | 451 ++++++++++++------ aten/src/ATen/native/cpu/SortingKernel.cpp | 8 +- .../ATen/native/cuda/ScatterGatherKernel.cu | 8 +- aten/src/ATen/native/cuda/TriangularOps.cu | 2 +- aten/src/ATen/record_function.h | 2 + 5 files changed, 305 insertions(+), 166 deletions(-) diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 9f0c51166172..41252609953f 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -2,17 +2,59 @@ #include #include +#include #include #include #include +#include #include +#include + +namespace at { +/** + * The method should_include_kernel_dtype() returns true/false + * based on whether the switching code for a specific dtype should be + * included based on build time constants generated from tracing model + * execution. This method will be implmeneted via code-generation and + * included in this file when code-gen is ready. + */ +inline constexpr bool should_include_kernel_dtype( + const char *kernel_tag_str, + at::ScalarType scalar_type +) { + return true; +} +} + +/** + * In the Facebook internal build (using BUCK), this macro is enabled by + * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer + * binary. + */ +#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ + {RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::KERNEL_FUNCTION_DTYPE, \ + std::string(NAME) + "$" + toString(enum_type), \ + {});} +#else +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) +#endif -#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \ - case enum_type: { \ - using scalar_t = type; \ - return __VA_ARGS__(); \ +#define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \ + [&] { \ + AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \ + } \ + ); \ + using HINT = type; \ + return __VA_ARGS__(); \ } +#define AT_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, scalar_t, __VA_ARGS__) + // Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused // attribute in the type aliasing context. Keep name long and verbose to avoid // macro collisions. @@ -143,6 +185,21 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} // 4. Should complex be supported? The answer is almost always no, // unless you are working on "generic" code that should work on // all dtypes. +// +// Parameters: +// ----------- +// +// 1. The NAME argument is a "tag" that is used to trace and then +// conditionally compile fragments of the case statements such +// that the kernel functions are specialized only for the dtypes +// that are needed. The NAME parameter *must* be a build time +// cons char* (can't be std::string, etc...) +// +// Please ensure that the NAME is unique for every implementation +// or you run the risk of over-including code for the kernel +// functions. There is no risk of missing out on any code, so +// it's mostly a risk of a Type-2 error, and not a Type-1 error. +// // NB: the the_type variable is not used, but we have kept it for // backwards compatibility. It's probably not used by anyone though; @@ -154,26 +211,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ @@ -181,10 +240,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -199,14 +259,17 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -220,13 +283,20 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -238,14 +308,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -259,19 +333,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} [&] { \ const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -285,31 +368,36 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() @@ -318,17 +406,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ @@ -336,11 +425,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -351,6 +447,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_QINT_PRIVATE_CASE_TYPE( \ at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ @@ -368,6 +465,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ @@ -387,17 +485,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op*/ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ @@ -406,154 +505,196 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_ALL_TYPES_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() @@ -562,15 +703,10 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_index_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _it = ::detail::scalar_type(the_index_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it) \ switch (_it) { \ - case at::ScalarType::Int: { \ - using index_t = int32_t; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Long: { \ - using index_t = int64_t; \ - return __VA_ARGS__(); \ - } \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Int, int32_t, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__)\ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \ } \ @@ -586,15 +722,16 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index 7d13de185509..1d69af7c5622 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -47,10 +47,10 @@ void _dim_apply( auto values_dim_stride = values.stride(dim); auto indices_dim_stride = indices.stride(dim); auto dim_size = values.size(dim); - + AT_DISPATCH_ALL_TYPES_AND2( ScalarType::Bool, ScalarType::Half, iter.dtype(), - method_name, [&] { + "sorting_kernel_method_name", [&] { auto loop = [&](char** data, const int64_t* strides, int64_t n) { auto* values_data_bytes = data[0]; auto* indices_data_bytes = data[1]; @@ -68,7 +68,7 @@ void _dim_apply( indices_data_bytes += strides[1]; } }; - + iter.for_each(loop); } ); @@ -114,7 +114,7 @@ static void sort_kernel( auto composite_accessor = CompositeRandomAccessorCPU< decltype(values_accessor), decltype(indices_accessor) >(values_accessor, indices_accessor); - + if (descending) { std::sort(composite_accessor, composite_accessor + dim_size, KeyValueCompDesc()); diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 66ac81f5ecbf..ff3b5bb08baa 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -192,7 +192,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -264,7 +264,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -365,7 +365,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -417,7 +417,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; diff --git a/aten/src/ATen/native/cuda/TriangularOps.cu b/aten/src/ATen/native/cuda/TriangularOps.cu index 6ba73e1c143e..8d497b5c94af 100644 --- a/aten/src/ATen/native/cuda/TriangularOps.cu +++ b/aten/src/ATen/native/cuda/TriangularOps.cu @@ -60,7 +60,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c int64_t N = self.numel(); dim3 dim_block = cuda::getApplyBlock(); dim3 dim_grid((N + dim_block.x - 1) / dim_block.x); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), name, [&]{ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu_tril_cuda_template", [&]{ if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) { auto result_info = cuda::detail::getTensorInfo(result); auto self_info = cuda::detail::getTensorInfo(self); diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 4b07d13aa747..43c2d878840d 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -23,6 +23,8 @@ enum class C10_API_ENUM RecordScope : uint8_t { BACKWARD_FUNCTION, // TorchScript functions, methods TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, // User defined scope (e.g. with record_function()) USER_SCOPE, NUM_SCOPES, // must be the last in the list From c0a0845019f8951645cd2bd7fe5663d5cf552dfe Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Fri, 11 Dec 2020 10:25:47 -0800 Subject: [PATCH 06/33] Improve new_group example in the context of SyncBatchNorm (#48897) Summary: Closes https://github.com/pytorch/pytorch/issues/48804 Improves some documentation/example in SyncBN docs to clearly show that each rank must call into all `new_group()` calls for creating process subgroups, even if they are not going to be part of that particular subgroup. We then pick the right group, i.e. the group that the rank is part of, and pass that into the SyncBN APIs. Doc rendering: syncbn_update Pull Request resolved: https://github.com/pytorch/pytorch/pull/48897 Reviewed By: zou3519 Differential Revision: D25493181 Pulled By: rohan-varma fbshipit-source-id: a7e93fc8cc07ec7797e5dbc356f1c3877342cfa3 --- torch/nn/modules/batchnorm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index e76e307d36a6..48e58d637ea6 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -434,8 +434,14 @@ class SyncBatchNorm(_BatchNorm): >>> # With Learnable Parameters >>> m = nn.SyncBatchNorm(100) >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) >>> input = torch.randn(20, 100, 35, 45, 10) @@ -564,8 +570,14 @@ def convert_sync_batchnorm(cls, module, process_group=None): >>> torch.nn.BatchNorm1d(100), >>> ).cuda() >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) """ From 42c78ed74525f03f6bb43110784f07c5a6ef1bef Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 11 Dec 2020 10:58:54 -0800 Subject: [PATCH 07/33] Tuple Slice with both negative and positive stepped size (#48660) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48660 We used to support tuple slicing without any step size before, but this PR extends this feature to support arbitrary step size. We do this by manually reconstructing a new tuple in the IR instead of relying on TupleSlice prim. Test Plan: python tests Imported from OSS Reviewed By: gmagogsfm Differential Revision: D25359336 fbshipit-source-id: 28cde536f28dd8a00607814b2900765e177f0ed7 --- test/test_jit.py | 26 +++++- tools/build_variables.bzl | 1 + torch/csrc/jit/frontend/ir_emitter.cpp | 87 ++++++++++++------- torch/csrc/jit/ir/ir.cpp | 30 ++++--- torch/csrc/jit/ir/ir.h | 6 +- .../csrc/jit/runtime/slice_indices_adjust.cpp | 56 ++++++++++++ torch/csrc/jit/runtime/slice_indices_adjust.h | 28 ++++++ 7 files changed, 186 insertions(+), 48 deletions(-) create mode 100644 torch/csrc/jit/runtime/slice_indices_adjust.cpp create mode 100644 torch/csrc/jit/runtime/slice_indices_adjust.h diff --git a/test/test_jit.py b/test/test_jit.py index 65b9c110f64f..239e4660674b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -12128,10 +12128,10 @@ def tuple_slice(a): scripted_fn = torch.jit.script(tuple_slice) self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) tuple_graph = scripted_fn.graph - slices = tuple_graph.findAllNodes("prim::TupleSlice") + slices = tuple_graph.findAllNodes("prim::TupleConstruct") num_outputs = set(len(x.output().type().elements()) for x in slices) - # one tuple slice should have an output with 2 elements, other 4 - self.assertTrue(num_outputs == {2, 4}) + # there should be only one tupleSlice with length of 2 + self.assertTrue(num_outputs == {2}) self.run_pass('lower_all_tuples', tuple_graph) self.assertTrue('Tuple' not in str(tuple_graph)) @@ -12142,6 +12142,26 @@ def test_indexing_end_out_of_bounds(): self.assertEqual(test_indexing_end_out_of_bounds(), ()) + def test_stepped_tuple_slicing(self): + + def check_slicing_tuple(slicing, tuple_type, tuple): + template = dedent(""" + def func(x): + # type: ({}) -> Any + return x{} + """) + self._check_code(template.format(tuple_type, slicing), "func", [tuple]) + + check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2)) + check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)) + check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + def test_lower_nested_tuples(self): @torch.jit.script def test(): diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 9121b7c84537..8b6374e9d71c 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -96,6 +96,7 @@ core_sources_common = [ "torch/csrc/jit/runtime/jit_exception.cpp", "torch/csrc/jit/runtime/operator.cpp", "torch/csrc/jit/runtime/print_handler.cpp", + "torch/csrc/jit/runtime/slice_indices_adjust.cpp", "torch/csrc/jit/runtime/register_ops_utils.cpp", "torch/csrc/jit/runtime/vararg_functions.cpp", "torch/csrc/jit/serialization/unpickler.cpp", diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index a21041343eee..02ead1d6fa80 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -3443,7 +3444,25 @@ struct to_ir { } else { AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get())); } + // TODO for now let's deal with TupleType first. Ideally all list, tensor, + // string, and tuple slicing should be same (tugsbayasgalan) + if (sliceable->type()->cast()) { + std::vector> tuple_args; + // since we are only dealing with tuple slicing for now, we try to keep + // tuple args seperate for now + tuple_args.reserve(3); + + start ? tuple_args.emplace_back(start) + : tuple_args.emplace_back(c10::nullopt); + end ? tuple_args.emplace_back(end) + : tuple_args.emplace_back(c10::nullopt); + step ? tuple_args.emplace_back(step) + : tuple_args.emplace_back(c10::nullopt); + + return emitTupleSlice(loc, args[0], tuple_args); + } + // TODO this needs to be cleaned for list slicing // Default value for start is 0. if (!start) { start = graph->insertConstant(0, loc); @@ -3453,19 +3472,6 @@ struct to_ir { if (end) { args.emplace_back(loc, "end", end); } - if (sliceable->type()->cast()) { - if (step) { - // TODO: add support for slicing tuples with a step - throw ErrorReport(loc) - << "Unsupported operation: slicing tuples with a step isn't supported"; - } - - if (end) { - return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]); - } else { - return emitTupleSlice(loc, args[0], args[1], c10::nullopt); - } - } if (!step) { step = graph->insertConstant(1, loc); @@ -3828,28 +3834,37 @@ struct to_ir { Value* emitTupleSlice( const SourceRange& loc, const NamedValue& tuple_val, - const NamedValue& beg_val, - const at::optional& end_val) { + const std::vector>& tuple_args) { auto tuple_type = tuple_val.value(*graph)->type()->expect(); - int64_t beg = getAdjTupleIndex( - loc, - tuple_type, - getSliceInd(beg_val.value(*graph), loc), - /*allow_out_of_bounds*/ true); - int64_t end; int64_t tuple_len = tuple_type->elements().size(); + auto beg_val = tuple_args[0]; + auto end_val = tuple_args[1]; + auto step = tuple_args[2]; + + int64_t step_size = 1; + if (step) { + auto val = toIValue(step->value(*graph)); + TORCH_CHECK(val->isInt(), "Step size should always be an integer"); + step_size = val->to(); + } + + int64_t beg = std::numeric_limits::max(); + if (beg_val) { + beg = getAdjTupleIndex( + loc, tuple_type, getSliceInd(beg_val->value(*graph), loc), true); + } + + int64_t end = std::numeric_limits::max(); if (end_val) { end = getAdjTupleIndex( loc, tuple_type, getSliceInd(end_val->value(*graph), loc), true); - } else { - end = tuple_len; } - // slicing does not throw out of bounds errors - end = std::min(std::max((int64_t)0, end), tuple_len); - beg = std::min(std::max((int64_t)0, beg), tuple_len); + + int64_t num_values = slice_indices_adjust(tuple_len, &beg, &end, step_size); return graph - ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end)) + ->insertNode(graph->createTupleSlice( + tuple_val.value(*graph), beg, step_size, num_values)) ->output(); } @@ -3873,19 +3888,25 @@ struct to_ir { auto s_tuple_val = sv->asTupleValue(val_range, method)->asValue(val_range, method); const SliceExpr& slice = SliceExpr(subscript_exprs[0]); + std::vector> tuple_args; + tuple_args.reserve(3); auto begin = NamedValue(val_range, "begin", emitExpr(Expr(slice.startOr(0)))); + tuple_args.emplace_back(begin); if (slice.end().present()) { auto end = NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, end); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(end); + } else { - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, c10::nullopt); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(c10::nullopt); } + // pushing step_size to match the tuple_args + tuple_args.emplace_back(c10::nullopt); + + auto tupleSliceValue = + emitTupleSlice(val_range, s_tuple_val, tuple_args); + return std::make_shared(tupleSliceValue); } else { return std::make_shared(emitBasicSlice( range, sv->asValue(val_range, method), subscript_exprs)); diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 4714a6ae12f6..65b410d82069 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1606,17 +1606,25 @@ Node* Graph::createTupleIndex( return n; } -Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) { - auto n = create(prim::TupleSlice, {tup}); - auto tuple_type = tup->type()->expect(); - n->i_(attr::beg, beg); - n->i_(attr::end, end); - std::vector output_types; - for (auto i = beg; i < end; ++i) { - output_types.push_back(tuple_type->elements().at(i)); - } - auto tt = TupleType::create(std::move(output_types)); - n->output()->setType(tt); +Node* Graph::createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values) { + std::vector new_vals; + TupleTypePtr tt = tup->type()->expect(); + new_vals.reserve(num_values); + + int64_t i = beg; + for (int64_t j = 0; j < num_values; ++j) { + auto idx = insertConstant(IValue(static_cast(i))); + auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i])); + + new_vals.push_back(tupleIndex->output()); + i += step_size; + } + + auto n = createTuple(new_vals); return n; } diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index b20d5611c55c..7587451d9fd4 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1122,7 +1122,11 @@ struct Graph { Value* tup, Value* idx, const TypePtr& output_type); - TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end); + TORCH_API Node* createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values); TORCH_API Node* createEnumName(Value* e); TORCH_API Node* createEnumValue(Value* e); TORCH_API Node* createList( diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.cpp b/torch/csrc/jit/runtime/slice_indices_adjust.cpp new file mode 100644 index 000000000000..e71d6ba94c9a --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +namespace torch { +namespace jit { + +int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step) { + TORCH_CHECK(step != 0, "List slice should have non-zero step") + TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds") + + // Comes from PySlice_Unpack. + if (*start == INT64_MAX) { + *start = (step < 0) ? INT64_MAX : 0; + } + if (*stop == INT64_MAX) { + *stop = (step < 0) ? INT64_MIN : INT64_MAX; + } + + // Comes from PySlice_AdjustIndices. + if (*start < 0) { + *start += length; + if (*start < 0) { + *start = (step < 0) ? -1 : 0; + } + } else if (*start >= length) { + *start = (step < 0) ? length - 1 : length; + } + + if (*stop < 0) { + *stop += length; + if (*stop < 0) { + *stop = (step < 0) ? -1 : 0; + } + } else if (*stop >= length) { + *stop = (step < 0) ? length - 1 : length; + } + + if (step < 0) { + if (*stop < *start) { + return (*start - *stop - 1) / (-step) + 1; + } + } else { + if (*start < *stop) { + return (*stop - *start - 1) / step + 1; + } + } + return 0; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.h b/torch/csrc/jit/runtime/slice_indices_adjust.h new file mode 100644 index 000000000000..ea1e9511769d --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { + +// Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +// 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software +// Foundation; All Rights Reserved +// +// Stolen (with appropriate modifications) by @agolynski +// (https://github.com/pytorch/pytorch/pull/33019) from cpython repo +// Objects/sliceobject.c with comment: this is harder to get right than you +// might think +// +// This adjusts indexes according to python list semantics and returns number +// of elements in the resulting list. +TORCH_API int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step); + +} // namespace jit +} // namespace torch From f965b0fcfbcfac0a4cd699c8336ad271be86811e Mon Sep 17 00:00:00 2001 From: Shijun Kong Date: Fri, 11 Dec 2020 11:15:53 -0800 Subject: [PATCH 08/33] Expose run_async function on torch::jit::Method (#48607) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48607 This change builds on top of https://github.com/pytorch/pytorch/pull/46865 further exposing the async interface to `torch::jit::Method`. added unit test for new `run_async` Test Plan: `buck test caffe2/test/cpp/jit/...` Reviewed By: dzhulgakov Differential Revision: D25219726 fbshipit-source-id: 89743c82a0baa1affe0254c1e2dbf873de8e5c76 --- test/cpp/jit/test_module_api.cpp | 38 ++++++++++++++++++++++++++++++++ torch/csrc/jit/api/method.h | 9 ++++++++ torch/csrc/jit/api/module.cpp | 11 +++++++++ 3 files changed, 58 insertions(+) diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp index 910331166d51..c77d89af5afa 100644 --- a/test/cpp/jit/test_module_api.cpp +++ b/test/cpp/jit/test_module_api.cpp @@ -43,6 +43,44 @@ static void import_libs( si.loadType(QualifiedName(class_name)); } +TEST(ModuleAPITest, MethodRunAsync) { + // Module m("m"); + // m.define(R"( + // def forward(self): + // r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // return r1.wait() + r2.wait() + // )"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + // borrow model file from TEST(GraphExecutorTest, runAsync_executor) + testModelFile.append("test_interpreter_async.pt"); + auto m = load(testModelFile); + + auto counter = 0; + std::mutex mtx; + + auto launcher = [&](std::function f) { + mtx.lock(); + ++counter; + mtx.unlock(); + at::launch(move(f)); + }; + + auto method = m.get_method("forward"); + + std::vector stack; + auto kwargs = std::unordered_map(); + auto future = method.run_async(stack, kwargs, launcher); + + future->wait(); + + // expect 2 forks and 2 wait callbacks being excuted on provided taskLauncher + // but ivalue::Future would be marked completed and release wait before + // finishing all callbacks + ASSERT_GE(counter, 2); +} + TEST(ModuleAPITest, Clone) { auto cu = std::make_shared(); // creating child module diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index 1d0ea9bce2c8..96b632b6b111 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -31,6 +31,15 @@ struct TORCH_API Method { std::vector stack, const Kwargs& kwargs = Kwargs()); + // Run method async. Invocation on this function would invokes a JIT + // interpreter that executes ops inline, one by one, on caller's thread. A + // model can utilize async op, i.e. `fork`, to launch an asynchronous task + // which will be launched on provided `taskLauncher`. + c10::intrusive_ptr run_async( + std::vector stack, + const Kwargs& kwargs = Kwargs(), + TaskLauncher taskLauncher = at::launch); + std::shared_ptr graph() const { return function_->graph(); } diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 04eafc3d0f5d..d74905b5d3f0 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -118,6 +118,17 @@ IValue Method::operator()(std::vector stack, const Kwargs& kwargs) { return (*function_)(std::move(stack), kwargs); } +c10::intrusive_ptr Method::run_async( + std::vector stack, + const Kwargs& kwargs, + TaskLauncher taskLauncher) { + stack.insert(stack.begin(), owner()._ivalue()); + RECORD_TORCHSCRIPT_FUNCTION(name(), stack); + + function_->getSchema().checkAndNormalizeInputs(stack, kwargs); + return function_->runAsync(stack, std::move(taskLauncher)); +} + void Module::clone_method( const Module& orig, const Function& method, From 796b267763dee6e3451dacbf1c22c77b98ef91d9 Mon Sep 17 00:00:00 2001 From: Alexander Golynski Date: Fri, 11 Dec 2020 12:03:52 -0800 Subject: [PATCH 09/33] fix backwards compatibility for #48711 and its revert (#49240) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49240 Test Plan: Imported from OSS Reviewed By: heitorschueroff Differential Revision: D25500727 Pulled By: agolynski fbshipit-source-id: 6a690f52fe671267862b159b6330d37ef08ee291 --- test/backward_compatibility/check_backward_compatibility.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index ccb4a6457537..e155537d7b99 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -187,6 +187,8 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::quantile", datetime.date(2021, 1, 31)), + ("aten::nanquantile", datetime.date(2021, 1, 31)), ] def allow_listed(schema, allow_list): From 2a3bb1cea0f2b070e66f032d26082a7a38e0e217 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 11 Dec 2020 12:11:13 -0800 Subject: [PATCH 10/33] [quant][graphmode][fx][fix] Fix typo in fusion (#49183) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49183 Test Plan: Imported from OSS Reviewed By: hx89 Differential Revision: D25473367 fbshipit-source-id: 0cd5e6769eeea0923dd104ea90b0192e3475b3ad --- torch/quantization/fx/fuse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index b5cf78b05f33..5aabbd66c4b1 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -35,7 +35,7 @@ def fuse(self, model: GraphModule, self.modules = dict(input_root.named_modules()) additional_fusion_patterns = \ - fuse_custom_config_dict.get("additional_quant_pattern", {}) + fuse_custom_config_dict.get("additional_fusion_pattern", {}) fusion_patterns = get_combined_dict( get_default_fusion_patterns(), additional_fusion_patterns) # find fusion From 1cb5aa6c6039038095f3a505d7f3bdb5a2d6a1d4 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Fri, 11 Dec 2020 12:36:15 -0800 Subject: [PATCH 11/33] Fix structured kernel codegen (#49244) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49244 see https://fb.quip.com/ceEdANd5iVsO RegisterMkldnnCPU kernels incorrectly used makeUnboxedOnly() calls to register add_.Tensor kernels. This is because the codegen incorrectly thought they're not c10-full. This PR fixes that. ghstack-source-id: 118411117 Test Plan: After this PR, RegisterMkldnnCPU doesn't contain the makeUnboxedOnly() calls anymore. Reviewed By: ezyang Differential Revision: D25500246 fbshipit-source-id: 8a8c2be9c4f4a5ce7eaae94257c2f8cbd176e92e --- tools/codegen/gen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index af3ebbf674f4..8c22c1fe702c 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -471,6 +471,7 @@ def gen_one(f: NativeFunction) -> Optional[str]: return list(mapMaybe(gen_one, g.functions())) + @method_with_native_function def gen_unstructured(self, f: NativeFunction) -> Optional[str]: # for mypy type refinement; would be fixed by TODO on target assert self.target is not Target.DECLARATION From db5e5b439c454d657cfa8f08a096cf68e203f2a8 Mon Sep 17 00:00:00 2001 From: Ilia Cherniavskii Date: Fri, 11 Dec 2020 12:51:43 -0800 Subject: [PATCH 12/33] Extra sampling of record function events [resend] (#49114) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49114 resend of https://github.com/pytorch/pytorch/pull/48289 Test Plan: see 48289 Reviewed By: robieta Differential Revision: D25443365 Pulled By: ilia-cher fbshipit-source-id: c15ac312222bb4d744e10199ed79801cccae8227 --- aten/src/ATen/ThreadLocalState.cpp | 1 + aten/src/ATen/ThreadLocalState.h | 24 ++++- aten/src/ATen/core/dispatch/Dispatcher.h | 81 +++++++++------- aten/src/ATen/record_function.cpp | 117 ++++++++++++++++++----- aten/src/ATen/record_function.h | 27 +++++- binaries/record_function_benchmark.cc | 101 +++++++++---------- torch/csrc/autograd/function.h | 39 ++++---- torch/csrc/jit/runtime/interpreter.cpp | 5 +- 8 files changed, 268 insertions(+), 127 deletions(-) diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 6d74e2f47ce0..3c7b9b6ff5bc 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -19,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode) grad_mode_enabled_ = GradMode::is_enabled(); } #endif + bumped_record_all_functions_ = at::checkRecordAllFunctions(); } /* static */ diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index f0cb85f0ff84..3c9b55b3d8d6 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -38,6 +38,9 @@ class TORCH_API ThreadLocalState { bool grad_mode_enabled_; #endif + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; + friend class ThreadLocalStateGuard; }; @@ -45,7 +48,21 @@ class TORCH_API ThreadLocalState { class TORCH_API ThreadLocalStateGuard { public: explicit ThreadLocalStateGuard(const ThreadLocalState& state) - : prev_state_(ThreadLocalState()) { + : prev_state_(ThreadLocalState()), + bumped_record_all_functions_(state.bumped_record_all_functions_) { + // Special handling of RecordFunction pre-sampling optimization: + // pre-samping is enabled (bumped) when there're non-sampled + // (or high-frequency) global or TLS callbacks. + // + // ThreadLocalStateGuard simply resets RecordFunction's TLS and + // hence its thread local callbacks. + // + // Checking if the pre-sampling was enabled and preserving it in the + // async task by calling bumpRecordAllFunctions() and the corresponding + // releaseRecordAllFunctions() + if (bumped_record_all_functions_) { + at::bumpRecordAllFunctions(); + } // set the given state across the thread boundary ThreadLocalState::setThreadLocalState(state); } @@ -53,10 +70,15 @@ class TORCH_API ThreadLocalStateGuard { ~ThreadLocalStateGuard() { // restore previously set variables ThreadLocalState::setThreadLocalState(prev_state_); + if (bumped_record_all_functions_) { + at::releaseRecordAllFunctions(); + } } private: const ThreadLocalState prev_state_; + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; }; template diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 632739053c42..f83302e2d819 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -371,28 +371,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandleop.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // Check if we need to run callbacks registered with RecordFunction - // If true and callbacks need inputs, we box the arguments and pass - // them into the callbacks and also into the kernel call - - // Note: for perf reasons we wouldn't want to pass arguments into - // the function call or prematurely box them - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.isActive())) { - if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { - int64_t seq_num = -1; - // Setting sequence number in the Autograd case to associate - // the forward range with the coresponding Autograd's node - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needsInputs()) { - torch::jit::Stack stack = impl::boxArgs(args...); - guard.before(op, stack, seq_num); - } else { - guard.before(op, seq_num); + // By default, when there're no high-frequency or non-sampled callbacks, + // RecordFunction is pre-sampled as a perf optimization; + // shouldRunRecordFunction checks whether RecordFunction should be executed, + // and sets pre_sampled boolean argument value to whether pre-sampling was used - + // this boolean is passed into RecordFunction to adjust the sampling rates of + // the callbacks + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // Check if we need to run callbacks registered with RecordFunction + // If true and callbacks need inputs, we box the arguments and pass + // them into the callbacks and also into the kernel call + + // Note: for perf reasons we wouldn't want to pass arguments into + // the function call or prematurely box them + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { + int64_t seq_num = -1; + // Setting sequence number in the Autograd case to associate + // the forward range with the coresponding Autograd's node + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + torch::jit::Stack stack = impl::boxArgs(args...); + guard.before(op, stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + return kernel.template call(op, std::forward(args)...); } #endif // PYTORCH_DISABLE_PER_OP_PROFILING return kernel.template call(op, std::forward(args)...); @@ -429,20 +440,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const const auto& kernel = entry.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // using already existing stack to record function execution in observers - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.isActive())) { - if (shouldRecord(dispatchKey) && entry.isObserved()) { - int64_t seq_num = -1; - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needsInputs()) { - guard.before(op, *stack, seq_num); - } else { - guard.before(op, seq_num); + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // using already existing stack to record function execution in observers + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && entry.isObserved()) { + int64_t seq_num = -1; + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + guard.before(op, *stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + kernel.callBoxed(op, stack); + return; } #endif // PYTORCH_DISABLE_PER_OP_PROFILING kernel.callBoxed(op, stack); diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 102931fd4aa7..d1b0acb87c28 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -30,8 +30,6 @@ std::atomic defaultNodeId(-1); std::atomic next_thread_id_ {0}; thread_local uint64_t current_thread_id_ = 0; -thread_local bool tls_record_function_enabled_ = true; - // Low probability constant static const double kLowProb = 0.001; struct CoinflipTLS { @@ -68,6 +66,10 @@ void set_record_function_tls_(const RecordFunctionTLS& tls) { class CallbackManager { public: CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } // note: monotonically increasing callbacks_unique_id keeps // sorted_tls_callbacks_ sorted auto handle = next_unique_callback_handle(); @@ -76,6 +78,10 @@ class CallbackManager { } CallbackHandle addGlobalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } auto handle = next_unique_callback_handle(); sorted_global_callbacks_.emplace_back(std::move(cb), handle); return handle; @@ -92,6 +98,10 @@ class CallbackManager { return el.second == handle; }); if (it != cbs.end()) { + if (it->first.samplingProb() > kLowProb) { + // try to restore pre-sampling of RecordFunction + at::releaseRecordAllFunctions(); + } // keeps it sorted cbs.erase(it); return true; @@ -127,7 +137,13 @@ class CallbackManager { // callbackShouldRun is even hotter because it's called multiple // times per init(). Profiling shows that the function prologue is // taking up a significant fraction of the time. - static bool C10_ALWAYS_INLINE callbackShouldRun(const RecordFunctionCallback& cb, RecordScope scope) { + static bool C10_ALWAYS_INLINE callbackShouldRun( + const RecordFunctionCallback& cb, RecordScope scope, bool pre_sampled) { + TORCH_INTERNAL_ASSERT( + !pre_sampled || (cb.sampling_prob_ <= kLowProb), + "Incorrect usage of a pre-sampled RecordFunction with a high-frequency " + " or non-sampled callback"); + // first check whether this callback is interested in // the given scope type if (!cb.checkScope(scope)) { @@ -138,36 +154,45 @@ class CallbackManager { return cb.should_run_(cb); } - if (cb.sampling_prob_ == 1.0) { - return true; + // otherwise potentially do the sampling + double sampling_prob = cb.sampling_prob_; + if (pre_sampled) { + // adjust the sampling rate to account for kLowProb pre-sampling of + // the RecordFunction + sampling_prob /= kLowProb; } - // model the low probability events as events happening - // with probability kLowProb followed by another sampling with - // probability (sampling_prob__ / kLowProb), then replace the coin - // flip for kLowProb with a thread local number of tries tries_left_ - // sampled from the geometric distribution. - if (cb.sampling_prob_ < kLowProb) { - if (coinflip_tls_.tries_left_ == 0) { - coinflip_tls_.tries_left_ = sample_geometric(); - return (sample_zero_one() < cb.sampling_prob_ / kLowProb); + + if (sampling_prob < 1.0) { + // model the low probability events as events happening + // with probability kLowProb followed by another sampling with + // probability (sampling_prob / kLowProb), then replace the coin + // flip for kLowProb with a thread local number of tries tries_left_ + // sampled from the geometric distribution. + if (sampling_prob < kLowProb) { + if (coinflip_tls_.tries_left_ == 0) { + coinflip_tls_.tries_left_ = sample_geometric(); + return (sample_zero_one() < sampling_prob / kLowProb); + } else { + --coinflip_tls_.tries_left_; + return false; + } } else { - --coinflip_tls_.tries_left_; - return false; + return (sample_zero_one() < sampling_prob); } - } else { - return (sample_zero_one() < cb.sampling_prob_); } + + return true; } // init is called by RecordFunction in constructor to // determine which thread local and global callbacks are going // to be executed and whether any of them need inputs - inline void init(RecordFunction& rec_fn, RecordScope scope) { + inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) { bool found_needs_inputs = false; bool found_needs_ids = false; for (const auto& cb: rf_tls_.sorted_tls_callbacks_) { - if (callbackShouldRun(cb.first, scope)) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { if (cb.first.needsInputs()) { found_needs_inputs = true; } @@ -182,7 +207,7 @@ class CallbackManager { } for (const auto& cb: sorted_global_callbacks_) { - if (callbackShouldRun(cb.first, scope)) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { if (cb.first.needsInputs()) { found_needs_inputs = true; } @@ -308,7 +333,6 @@ namespace { } } // namespace - RecordFunctionCallbacks _getTLSCallbacks() { return rf_tls_.sorted_tls_callbacks_; } @@ -374,12 +398,12 @@ void enableRecordFunction(bool enable) { rf_tls_.tls_record_function_enabled_ = enable; } -RecordFunction::RecordFunction(RecordScope scope) { +RecordFunction::RecordFunction(RecordScope scope, bool pre_sampled) { auto* rf_tls_ptr = &rf_tls_; if (rf_tls_ptr->tls_record_function_enabled_) { auto& m = manager(); if (!m.sorted_global_callbacks_.empty() || !rf_tls_ptr->sorted_tls_callbacks_.empty()) { - m.init(*this, scope); + m.init(*this, scope, pre_sampled); } } } @@ -451,4 +475,49 @@ void RecordFunction::end() { } } +// RecordFunction pre-sampling +namespace { +// Whether to try to create RecordFunction on each call (>0) or +// use pre-sampling (=0) +std::atomic global_record_all_functions_ {0}; +} + +void bumpRecordAllFunctions() { + global_record_all_functions_.fetch_add(1, std::memory_order_relaxed); +} + +void releaseRecordAllFunctions() { + TORCH_CHECK(global_record_all_functions_.fetch_sub(1, std::memory_order_relaxed) >= 0); +} + +bool checkRecordAllFunctions() { + return (global_record_all_functions_.load(std::memory_order_relaxed) > 0); +} + +bool shouldRunRecordFunction(bool* pre_sampled) { + auto* rf_tls_ptr = &rf_tls_; + if (rf_tls_ptr->sorted_tls_callbacks_.empty() && !manager().hasGlobalCallbacks()) { + *pre_sampled = false; + return false; + } + if (global_record_all_functions_.load(std::memory_order_relaxed) > 0) { + *pre_sampled = false; + return true; + } + if (!rf_tls_ptr->tls_record_function_enabled_) { + *pre_sampled = false; + return false; + } + + *pre_sampled = true; + auto* coinflip_tls_ptr = &coinflip_tls_; + if (coinflip_tls_ptr->tries_left_ == 0) { + coinflip_tls_ptr->tries_left_ = sample_geometric(); + return true; + } else { + --coinflip_tls_ptr->tries_left_; + return false; + } +} + } // namespace at diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 43c2d878840d..bcd0fbc37e77 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -92,8 +92,11 @@ typedef uint64_t RecordFunctionHandle; struct TORCH_API RecordFunction { // Default constructor is used with before function called afterwards: // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability RecordFunction( - RecordScope scope = RecordScope::FUNCTION); + RecordScope scope = RecordScope::FUNCTION, + bool pre_sampled = false); template void before( @@ -240,6 +243,9 @@ struct TORCH_API RecordFunction { // flag is used to check whether the start callbacks were called bool called_start_callbacks_ = false; + // Whether the RecordFunction is pre-sampled + bool pre_sampled_ = false; + // Used internally to keep track of thread local and global callbacks // that were picked to run; must be sorted; CallbackHandles sorted_active_tls_handles_; @@ -332,7 +338,7 @@ class TORCH_API RecordFunctionCallback { } RecordFunctionCallback& samplingProb(double sampling_prob) { - TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob_ <= 1.0, + TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0, "Invalid sampling probability"); sampling_prob_ = sampling_prob; return *this; @@ -546,10 +552,27 @@ struct TORCH_API RecordFunctionTLS { RecordFunctionCallbacks sorted_tls_callbacks_; bool tls_record_function_enabled_ = true; + + // Stores the number of coin flips before the next successful coin flip + int tries_left_ = 0; }; TORCH_API const RecordFunctionTLS& get_record_function_tls_(); TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); +// Checks whether RecordFunction should be called, +// sets boolean pointed by the argument to whether pre-sampling was used +TORCH_API bool shouldRunRecordFunction(bool*); + +// The following functions are used to disable/enable pre-sampling of RecordFunction +// when high-frequency/non-sampled callbacks are added/removed. +// Note: every call to bumpRecordAllFunctions() is supposed to be matched with +// the corresponding releaseRecordAllFunctions() call. +// Note: disabling pre-sampling of RecordFunction incurs an extra overhead, since +// RecordFunction will be created for each operator call. +TORCH_API void bumpRecordAllFunctions(); +TORCH_API void releaseRecordAllFunctions(); +TORCH_API bool checkRecordAllFunctions(); + } // namespace at diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index d924003b9270..53a8bd16f43d 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -7,61 +7,55 @@ #include #include -C10_DEFINE_int(iter, 100, "Number of iterations"); -C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations"); +C10_DEFINE_int(iter, 10000, "Number of iterations"); C10_DEFINE_int(sampled_iter, 10e6, "Number of iterations for the sampled observer benchmark"); namespace { -const int kInnerIter = 100; -const int kNumSampledCb = 2; const int kTensorSize = 16; const int kSmallTensorSize = 1; -const float kSampingProb = 0.1; - const float kLowSamplingProb = 0.0001; } -void setupBenchmarkCallbacks() { - at::enableRecordFunction(); - at::clearCallbacks(); - // non-sampled callback - at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) {}, +void addTestCallback( + double sampling_prob = 1.0, + std::function fn = + [](const at::RecordFunction&) {}) { + auto cb = at::RecordFunctionCallback( + std::move(fn), [](const at::RecordFunction&) {}) - .needsInputs(true)); - - // sampled - for (auto idx = 0; idx < kNumSampledCb; ++idx) { - at::addGlobalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) {}, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kSampingProb) - ); + .needsInputs(false); + if (sampling_prob < 1.0) { + cb.samplingProb(sampling_prob); } + at::addGlobalCallback(cb); } -float runTensorBench(int tensor_size, int outer_iter) { +float runTensorGEMMBench(int tensor_size, int iter) { typedef std::chrono::high_resolution_clock clock; typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); - for (auto idx = 0; idx < kInnerIter * outer_iter; ++idx) { - torch::mm( - torch::randn({tensor_size, tensor_size}), - torch::randn({tensor_size, tensor_size})); + auto inp = torch::randn({tensor_size, tensor_size}); + for (auto idx = 0; idx < iter; ++idx) { + torch::mm(inp, inp); } auto duration = static_cast( std::chrono::duration_cast(clock::now() - start_time).count()); return duration; } -float runPureRecordFunctionBench(int outer_iter) { +float runPureRecordFunctionBench(int iter) { typedef std::chrono::high_resolution_clock clock; typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); - for (auto n = 0; n < outer_iter; ++n) { - RECORD_USER_SCOPE("test"); + for (auto idx = 0; idx < iter; ++idx) { + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + guard.before("Test", -1); + } + } } auto duration = static_cast( std::chrono::duration_cast(clock::now() - start_time).count()); @@ -71,18 +65,19 @@ float runPureRecordFunctionBench(int outer_iter) { void runBenchmark() { float duration = 0; for (auto tensor_size : std::set({kSmallTensorSize, kTensorSize})) { - duration = runTensorBench(tensor_size, FLAGS_iter); - std::cout << "Running tensor benchmark, time per iteration (" + duration = runTensorGEMMBench(tensor_size, FLAGS_iter); + std::cout << "Tensor GEMM benchmark (" << tensor_size << "x" << tensor_size - << "): " << (duration/FLAGS_iter) + << ", " << FLAGS_iter << "): " << duration << " us." << std::endl; } - duration = runPureRecordFunctionBench(FLAGS_iter * 100); - std::cout << "Running pure RecordFunction benchmark, time per iteration: " - << (duration/FLAGS_iter) - << " us." << std::endl; + duration = runPureRecordFunctionBench(FLAGS_iter); + std::cout << "Pure RecordFunction benchmark (" + << FLAGS_iter << "): " + << duration + << " us." << std::endl; } int main(int argc, char** argv) { @@ -91,32 +86,38 @@ int main(int argc, char** argv) { return -1; } - auto duration = runTensorBench(kSmallTensorSize, FLAGS_warmup_iter); - std::cout << "Warmup time: " << duration << " us." << std::endl; + at::enableRecordFunction(); + at::clearCallbacks(); - setupBenchmarkCallbacks(); - std::cout << "Running with empty observers" << std::endl; + std::cout << "Warm up" << std::endl; runBenchmark(); - at::clearCallbacks(); std::cout << "Running without observers" << std::endl; runBenchmark(); - std::cout << "Running sampled observer benchmark" << std::endl; + addTestCallback(); + std::cout << "Running with empty non-sampled observer" << std::endl; + runBenchmark(); + at::clearCallbacks(); + + addTestCallback(kLowSamplingProb); + std::cout << "Running with empty sampled observer" << std::endl; + runBenchmark(); + at::clearCallbacks(); + + std::cout << "Checking number of sampled observer invocations" << std::endl; int cb_count = 0; - at::addGlobalCallback(at::RecordFunctionCallback( + addTestCallback( + kLowSamplingProb, [&](const at::RecordFunction& fn) { ++cb_count; - }, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kLowSamplingProb) + } ); - runPureRecordFunctionBench(FLAGS_sampled_iter); + auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter); std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter - << " iterations " << duration + << " iterations: " << duration << " us, number of callback invocations: " << cb_count << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb) << " invocations" << std::endl; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 09dc048f214b..44171e1a3b1b 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -133,26 +133,33 @@ struct TORCH_API Node : std::enable_shared_from_this { /// Evaluates the function on the given inputs and returns the result of the /// function call. variable_list operator()(variable_list&& inputs) { - // Using RecordFunction to trogger observers in the backward pass - at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION); - if (guard.isActive()) { - // Using sequence number and thread id to correlate with - // the forward pass function - guard.setForwardThreadId(thread_id_); - if (guard.needsInputs()) { - guard.before( - name(), - std::vector(inputs.begin(), inputs.end()), - sequence_nr()); - } else { - guard.before(name(), sequence_nr()); - } - } // In the first iteration of named tensors, autograd ignores names and // operates on unnamed tensors. In the long term, autograd should // probably operate with names. at::NoNamesGuard no_names_guard; - return apply(std::move(inputs)); + + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + // Using RecordFunction to trogger observers in the backward pass + at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled); + if (guard.isActive()) { + // Using sequence number and thread id to correlate with + // the forward pass function + guard.setForwardThreadId(thread_id_); + if (guard.needsInputs()) { + guard.before( + name(), + std::vector(inputs.begin(), inputs.end()), + sequence_nr()); + } else { + guard.before(name(), sequence_nr()); + } + } + // keeping stack guard object alive during the call + return apply(std::move(inputs)); + } else { + return apply(std::move(inputs)); + } } // Graph Connectivity API diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 3a028175d9c3..5d88264a2f2c 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1609,10 +1609,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { + bool pre_sampled = false; if (!frame.record_function && at::hasCallbacks() && - at::isRecordFunctionEnabled()) { + at::shouldRunRecordFunction(&pre_sampled)) { auto rec_fn = std::make_unique( - at::RecordScope::TORCHSCRIPT_FUNCTION); + at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled); if (rec_fn->isActive()) { if (rec_fn->needsInputs()) { rec_fn->before( From 9920adebfd2ff2eda33f72f2d4589973f1581b76 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 11 Dec 2020 13:24:55 -0800 Subject: [PATCH 13/33] pyi cleanup (#49054) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49054 These are some followups from the first pyi codegen PR. Still maintaining byte-for-byte compatibility in this one. - Separated `argument_str() with a pyi flag into two functions, `argument_str()` and `argument_str_pyi()` - Added a notes section for pyi at the top of `python.py` - Added a `Python Interface` section that I moved the free-standing pyi functions to Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D25410848 Pulled By: bdhirsh fbshipit-source-id: db83a80af900c32b5e32d67ce27767f6e7c2adfb --- .jenkins/pytorch/codegen-test.sh | 1 - tools/autograd/gen_python_functions.py | 2 - tools/codegen/api/python.py | 261 ++++++++++++++----------- 3 files changed, 143 insertions(+), 121 deletions(-) diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh index 44f1e9449bf0..17e7e9fa3445 100755 --- a/.jenkins/pytorch/codegen-test.sh +++ b/.jenkins/pytorch/codegen-test.sh @@ -37,7 +37,6 @@ python -m tools.setup_helpers.generate_code \ mkdir -p "$OUT"/pyi/torch/_C mkdir -p "$OUT"/pyi/torch/nn python -m tools.pyi.gen_pyi \ - --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ --native-functions-path aten/src/ATen/native/native_functions.yaml \ --deprecated-functions-path tools/autograd/deprecated.yaml \ --out "$OUT"/pyi diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 63438a527b4c..47abce5466c6 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -228,8 +228,6 @@ def signature_original(f: NativeFunction) -> str: opname = str(f.func.name.name.base) if f.func.is_out_fn(): opname += '_out' - # TODO: remove HACK - # I think we want to differentiate inplace functions here.. but we currently don't for the arg parser if f.func.name.name.inplace and pyi: opname += '_' args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments() diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index 45fa1685a5cf..dadfed354106 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -13,6 +13,8 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # +# [Notes] python binding codegen +# # The Python binding codegen produces code that takes the input list of # PyObjects, finds the matching ATen C++ function using PythonArgParser, # converts the PyObjects into C++ types and calls the ATen C++ function: @@ -171,25 +173,15 @@ # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); # } # - -# TODO: stick this more firmly in the data model somewhere? -def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: - if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): - return [] - else: - if any(map(lambda r: r.name is None, returns)): - # When building on Windows, `PyStructSequence_UnnamedField` could not be - # resolved by the linker for some reason, which cause error in building: - # - # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol - # PyStructSequence_UnnamedField - # - # Thus, at this point in time, we do not support unnamed - # fields in namedtuple; you must either name all fields, - # or none of them. - raise ValueError("Unnamed field is not supported by codegen") - - return list(map(lambda r: str(r.name), returns)) +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. @dataclass(frozen=True) class PythonReturns: @@ -235,9 +227,30 @@ class PythonArgument: # Compute argument formal for python argument parsing. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. - def argument_str(self, *, method: bool = False, pyi: bool = False, deprecated: bool = False) -> str: - type_str = argument_type_str_pyi(self.type, pyi_out_arg=pyi and isinstance(self, PythonOutArgument)) \ - if pyi else argument_type_str(self.type) + def argument_str(self, *, method: bool = False) -> str: + type_str = argument_type_str(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == 'self' and type_str == 'Tensor' and not method: + name = 'input' + + # add default + if self.default is not None: + default = { + 'nullptr': 'None', + 'c10::nullopt': 'None', + '{}': 'None', + }.get(self.default, self.default) + return f'{type_str} {name}={default}' + else: + return f'{type_str} {name}' + + def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> str: + is_out_arg = isinstance(self, PythonOutArgument) + type_str = argument_type_str_pyi(self.type, pyi_out_arg=is_out_arg) name = self.name # s/self/input/ outside method bindings @@ -246,45 +259,33 @@ def argument_str(self, *, method: bool = False, pyi: bool = False, deprecated: b if name == 'self' and type_str == 'Tensor' and not method and not deprecated: name = 'input' - if pyi: - if name == 'from': # from is a Python keyword... - name += '_' - # pyi merges the _out and functional variants into the same signature, with an optional out arg - if name == 'out' and type_str == 'Tensor' and not deprecated: - type_str = 'Optional[' + type_str + ']' + if name == 'from': # from is a Python keyword... + name += '_' + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == 'out' and type_str == 'Tensor' and not deprecated: + type_str = 'Optional[' + type_str + ']' # TODO: remove diff. pyi deprecated signatures don't get defaults for their out arg - treat_as_no_default = pyi and deprecated and isinstance(self, PythonOutArgument) and self.default == 'None' + treat_as_no_default = deprecated and is_out_arg and self.default == 'None' # add default if self.default is not None and not treat_as_no_default: - if pyi: - if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \ - self.default.startswith('{') and self.default.endswith('}'): - default = '(' + self.default[1:-1] + ')' - else: - default = { - 'nullptr': 'None', - 'c10::nullopt': 'None', - '{}': 'None', - 'MemoryFormat::Contiguous': 'contiguous_format', - 'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine', - }.get(self.default, self.default) - # TODO: remove requires_grad special case (byte-for-byte compat) - return f'{name}:{type_str}={default}' if name == 'requires_grad' else f'{name}: {type_str}={default}' + if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \ + self.default.startswith('{') and self.default.endswith('}'): + default = '(' + self.default[1:-1] + ')' else: default = { 'nullptr': 'None', 'c10::nullopt': 'None', '{}': 'None', + 'MemoryFormat::Contiguous': 'contiguous_format', + 'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine', }.get(self.default, self.default) - return f'{type_str} {name}={default}' + # TODO: remove requires_grad special case (byte-for-byte compat) + return f'{name}:{type_str}={default}' if name == 'requires_grad' else f'{name}: {type_str}={default}' else: - if pyi: - # TODO: remove requires_grad special case (byte-for-byte compat) - return f'{name}:{type_str}' if name == 'requires_grad' else f'{name}: {type_str}' - else: - return f'{type_str} {name}' + # TODO: remove requires_grad special case (byte-for-byte compat) + return f'{name}:{type_str}' if name == 'requires_grad' else f'{name}: {type_str}' @dataclass(frozen=True) class PythonOutArgument(PythonArgument): @@ -391,8 +392,7 @@ def output_idx(self) -> int: # for error parsing. # # For a translation to mypy-valid type signatures, see - # signature_str_pyi. If you change any logic here, please - # check that file too. + # signature_str_pyi(). def signature_str(self, *, skip_outputs: bool = False) -> str: args = self.arguments(skip_outputs=skip_outputs) schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method), args)) @@ -404,7 +404,7 @@ def signature_str(self, *, skip_outputs: bool = False) -> str: def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str: args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) - schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args)) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: schema_formals.insert(positional_argc, '*') @@ -419,7 +419,7 @@ def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: boo def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]: # only pyi uses vararg signatures args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) - schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args)) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) # vararg only applies to pyi signatures. vararg variants are not generated for all signatures num_args = self.arguments_count() num_positionalargs = len(self.input_args) @@ -471,7 +471,7 @@ def signature_str(self, *, skip_outputs: bool = False) -> str: def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str: args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) - schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True, deprecated=True), args)) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: schema_formals.insert(positional_argc, '*') @@ -662,67 +662,6 @@ def argument(a: Argument) -> PythonArgument: default_init=None, ) -def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str: - add_optional = False - if isinstance(t, OptionalType): - t = t.elem - add_optional = True - - if isinstance(t, BaseType): - if t.name == BaseTy.int: - ret = '_int' - elif t.name == BaseTy.float: - ret = '_float' - elif t.name == BaseTy.str: - ret = 'str' - elif t.name == BaseTy.Scalar: - ret = 'Number' - elif t.name == BaseTy.ScalarType: - ret = '_dtype' - elif t.name == BaseTy.bool: - ret = '_bool' - elif t.name == BaseTy.QScheme: - ret = '_qscheme' - elif t.name == BaseTy.Layout: - ret = '_layout' - elif t.name == BaseTy.Device: - ret = 'Union[_device, str, None]' - elif t.name == BaseTy.MemoryFormat: - ret = 'memory_format' - elif t.name == BaseTy.Dimname: - ret = 'Union[str, ellipsis, None]' - elif t.name in [BaseTy.Tensor, BaseTy.Generator, - BaseTy.Storage, BaseTy.Stream, BaseTy.str]: - # These python schema type names line up with their function schema names - ret = t.name.name - - elif isinstance(t, ListType): - if pyi_out_arg and t.is_tensor_like(): - # TODO remove HACK - # pyi blindly treats all tensor-like out args as having type Tensor - return 'Tensor' - if str(t.elem) == 'int': - ret = 'Union[_int, _size]' if t.size is not None else '_size' - elif t.is_tensor_like(): - # TODO: this doesn't seem right... - # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] - # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] - if isinstance(t.elem, OptionalType): - add_optional = True - ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \ - 'Union[Tuple[Tensor, ...], List[Tensor]]' - elif str(t.elem) == 'float': - ret = 'Sequence[float]' - else: - elem = argument_type_str_pyi(t.elem) - ret = f'Sequence[{elem}]' - - if add_optional: - ret = 'Optional[' + ret + ']' - return ret - - raise RuntimeError(f'unrecognized type {repr(t)}') - # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> PythonSignature: args: List[Argument] = [] @@ -770,7 +709,7 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> tensor_options_args.append(PythonArgument( name='dtype', type=BaseType(BaseTy.ScalarType), - default=_dtype_default_type_hack(name, pyi=pyi), + default='None' if pyi else _dtype_default_type_hack(name), default_init='self.scalar_type()' if is_like_or_new_function else None, )) # TODO: probably a bug, kill this diff? @@ -816,12 +755,98 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> ) # TODO blowtorch -def _dtype_default_type_hack(name: str, *, pyi: bool) -> str: - if not pyi and (name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices'): +def _dtype_default_type_hack(name: str) -> str: + if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices': return 'torch.int64' else: return 'None' +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: + if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): + return [] + else: + if any(map(lambda r: r.name is None, returns)): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in namedtuple; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return list(map(lambda r: str(r.name), returns)) + +def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + if isinstance(t, BaseType): + if t.name == BaseTy.int: + ret = '_int' + elif t.name == BaseTy.float: + ret = '_float' + elif t.name == BaseTy.str: + ret = 'str' + elif t.name == BaseTy.Scalar: + ret = 'Number' + elif t.name == BaseTy.ScalarType: + ret = '_dtype' + elif t.name == BaseTy.bool: + ret = '_bool' + elif t.name == BaseTy.QScheme: + ret = '_qscheme' + elif t.name == BaseTy.Layout: + ret = '_layout' + elif t.name == BaseTy.Device: + ret = 'Union[_device, str, None]' + elif t.name == BaseTy.MemoryFormat: + ret = 'memory_format' + elif t.name == BaseTy.Dimname: + ret = 'Union[str, ellipsis, None]' + elif t.name in [BaseTy.Tensor, BaseTy.Generator, + BaseTy.Storage, BaseTy.Stream, BaseTy.str]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if pyi_out_arg and t.is_tensor_like(): + # TODO remove HACK + # pyi blindly treats all tensor-like out args as having type Tensor + return 'Tensor' + if str(t.elem) == 'int': + ret = 'Union[_int, _size]' if t.size is not None else '_size' + elif t.is_tensor_like(): + # TODO: this doesn't seem right... + # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] + # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] + if isinstance(t.elem, OptionalType): + add_optional = True + ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \ + 'Union[Tuple[Tensor, ...], List[Tensor]]' + elif str(t.elem) == 'float': + ret = 'Sequence[float]' + else: + elem = argument_type_str_pyi(t.elem) + ret = f'Sequence[{elem}]' + + if add_optional: + ret = 'Optional[' + ret + ']' + return ret + + raise RuntimeError(f'unrecognized type {repr(t)}') + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # C++ Function Dispatch From b94ec8c9f71b461d849b95f07c9ce8c31a366bbf Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 11 Dec 2020 13:24:55 -0800 Subject: [PATCH 14/33] pyi codegen - removing byte-for-byte compatibility hacks (#49055) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49055 Removed the majority of the TODO hacks that I added to the original pyi PR to maintain byte-for-byte compatibility. I left a few of the divergences between pyi deprecated vs. native signatures, since (a) they're smaller and (b) it might make more sense to kill the deprecated functions at some point entirely. Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D25410847 Pulled By: bdhirsh fbshipit-source-id: cf07cdda92f7492cd83d363cbb810e3810f6b8c8 --- tools/codegen/api/python.py | 68 +++++++++++------------------- tools/pyi/gen_pyi.py | 12 +----- torch/_C/_VariableFunctions.pyi.in | 2 +- 3 files changed, 27 insertions(+), 55 deletions(-) diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index dadfed354106..10483e2e3d76 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -249,8 +249,7 @@ def argument_str(self, *, method: bool = False) -> str: return f'{type_str} {name}' def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> str: - is_out_arg = isinstance(self, PythonOutArgument) - type_str = argument_type_str_pyi(self.type, pyi_out_arg=is_out_arg) + type_str = argument_type_str_pyi(self.type) name = self.name # s/self/input/ outside method bindings @@ -261,12 +260,13 @@ def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> if name == 'from': # from is a Python keyword... name += '_' + # pyi merges the _out and functional variants into the same signature, with an optional out arg if name == 'out' and type_str == 'Tensor' and not deprecated: type_str = 'Optional[' + type_str + ']' - # TODO: remove diff. pyi deprecated signatures don't get defaults for their out arg - treat_as_no_default = deprecated and is_out_arg and self.default == 'None' + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = deprecated and isinstance(self, PythonOutArgument) and self.default == 'None' # add default if self.default is not None and not treat_as_no_default: @@ -281,11 +281,9 @@ def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> 'MemoryFormat::Contiguous': 'contiguous_format', 'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine', }.get(self.default, self.default) - # TODO: remove requires_grad special case (byte-for-byte compat) - return f'{name}:{type_str}={default}' if name == 'requires_grad' else f'{name}: {type_str}={default}' + return f'{name}: {type_str}={default}' else: - # TODO: remove requires_grad special case (byte-for-byte compat) - return f'{name}:{type_str}' if name == 'requires_grad' else f'{name}: {type_str}' + return f'{name}: {type_str}' @dataclass(frozen=True) class PythonOutArgument(PythonArgument): @@ -357,23 +355,13 @@ def deprecated(self) -> bool: return False def arguments( - self, *, skip_outputs: bool = False, skip_tensor_options: bool = False, hacky_add_output: bool = False + self, *, skip_outputs: bool = False, skip_tensor_options: bool = False ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]: result: List[Union[PythonArgument, PythonOutArgument]] = [] result.extend(self.input_args) result.extend(self.input_kwargs) if self.output_args is not None and not skip_outputs: result.append(self.output_args) - # TODO: remove HACK - # in the existing pyi codegen, we tack on an optional out argument to every operator overload - # if there exists at least one overload with an out variant. This seems wrong. - elif hacky_add_output: - result.extend([PythonOutArgument( - name='out', - type=OptionalType(BaseType(BaseTy.Tensor)), - default='None', - default_init=None, - outputs=())]) if not skip_tensor_options: result.extend(self.tensor_options_args) return tuple(result) @@ -402,8 +390,8 @@ def signature_str(self, *, skip_outputs: bool = False) -> str: return f'{self.name}({", ".join(schema_formals)})' - def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str: - args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: @@ -416,9 +404,9 @@ def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: boo schema_formals.insert(0, "self") return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' - def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]: + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: # only pyi uses vararg signatures - args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) + args = self.arguments(skip_outputs=skip_outputs) schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) # vararg only applies to pyi signatures. vararg variants are not generated for all signatures num_args = self.arguments_count() @@ -469,8 +457,8 @@ def deprecated(self) -> bool: def signature_str(self, *, skip_outputs: bool = False) -> str: return PythonSignature.signature_str(self, skip_outputs=skip_outputs) + '|deprecated' - def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str: - args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output) + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: @@ -479,7 +467,7 @@ def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: boo returns_str = self.returns.returns_str_pyi() return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' - def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]: + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: # the codegen doesn't include vararg variants for deprecated signatures return None @@ -712,11 +700,9 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> default='None' if pyi else _dtype_default_type_hack(name), default_init='self.scalar_type()' if is_like_or_new_function else None, )) - # TODO: probably a bug, kill this diff? - # pyi signatures have a slightly different type/default for layout tensor_options_args.append(PythonArgument( name='layout', - type=BaseType(BaseTy.Layout) if pyi else OptionalType(BaseType(BaseTy.Layout)), + type=OptionalType(BaseType(BaseTy.Layout)), default='strided' if pyi else 'torch.strided', default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None, )) @@ -726,15 +712,12 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> default='None', default_init='self.device()' if is_like_or_new_function else None, )) - # TODO: probably a bug, kill this diff? - # pyi signatures don't include pin memory - if not pyi: - tensor_options_args.append(PythonArgument( - name='pin_memory', - type=BaseType(BaseTy.bool), - default='False', - default_init=None, - )) + tensor_options_args.append(PythonArgument( + name='pin_memory', + type=BaseType(BaseTy.bool), + default='False', + default_init=None, + )) tensor_options_args.append(PythonArgument( name='requires_grad', type=BaseType(BaseTy.bool), @@ -755,12 +738,13 @@ def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> ) # TODO blowtorch +# note: removing this will be BC-breaking. A quick test shows that +# randperm will otherwise default its dtype to torch.float64 def _dtype_default_type_hack(name: str) -> str: if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices': return 'torch.int64' else: return 'None' - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # Python Interface @@ -785,7 +769,7 @@ def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: return list(map(lambda r: str(r.name), returns)) -def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str: +def argument_type_str_pyi(t: Type) -> str: add_optional = False if isinstance(t, OptionalType): t = t.elem @@ -820,10 +804,6 @@ def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str: ret = t.name.name elif isinstance(t, ListType): - if pyi_out_arg and t.is_tensor_like(): - # TODO remove HACK - # pyi blindly treats all tensor-like out args as having type Tensor - return 'Tensor' if str(t.elem) == 'int': ret = 'Union[_int, _size]' if t.size is not None else '_size' elif t.is_tensor_like(): diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index ee5c38a4cf1c..9a3a0f520e54 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -213,23 +213,15 @@ def generate_type_hints(funcs: Sequence[PythonSignatureGroup], is_tensor: bool = type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) type_hints.append(type_hint) - # TODO: remove HACK - # the pyi codegen currently adds an optional out param in cases where the current op does NOT have an out variant, - # but an overload of the op DOES have an out variant. - # TODO: After that, we should consider killing this method entirely and operating per PythonSignatureGroup - # rather than grouping their overloads together - # (since there isn't much else semantically meaningful about grouping overloads) - # this hack also doesn't apply to deprecated ops - hacky_add_output = any_out and sig_group.outplace is None and not sig_group.signature.deprecated # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument # Generates the out variant if one exists. Otherwise, generate the functional variant type_hint = sig_group.signature.signature_str_pyi( - skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output) + skip_outputs=sig_group.outplace is None) type_hints.append(type_hint) # Some operators also additionally have a vararg variant of their signature type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( - skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output) + skip_outputs=sig_group.outplace is None) if type_hint_vararg: type_hints.append(type_hint_vararg) diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 1360ef079725..1afd8e6c73d7 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,6 +1,6 @@ # ${generated_comment} -from torch import Tensor, Generator, strided, memory_format, contiguous_format +from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar from torch._six import inf From 33a9b14da04bfd990fd454ce3dc6eaa1668f2159 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 11 Dec 2020 13:24:55 -0800 Subject: [PATCH 15/33] pyi codegen - removing byte-for-byte-compatibility hacks (sorting overloads) (#49056) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49056 This is another byte-for-byte compatibility hack. I'm now sorting pyi signature overloads (previously the codegen did not). Mostly put this in a separate PR just to more easily reason about the diff in the codegen output. Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D25410846 Pulled By: bdhirsh fbshipit-source-id: 06e5c32edbce610dd12ec7499014b41b23c646bd --- tools/autograd/gen_python_functions.py | 6 +----- tools/pyi/gen_pyi.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 47abce5466c6..570c99908853 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -637,8 +637,6 @@ def method_def( def group_overloads( overloads: Sequence[PythonSignatureNativeFunctionPair], - *, - sort: bool = True, ) -> Sequence[PythonSignatureGroup]: bases: Dict[str, PythonSignatureNativeFunctionPair] = {} outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} @@ -687,9 +685,7 @@ def group_overloads( outplace=outplace.function if outplace is not None else None, )) - # TODO: unconditionally sort - # maintaining byte-for-byte compatibility for pyi codegen for now - return grouped if not sort else sort_overloads(grouped) + return sort_overloads(grouped) # This function declares a partial order on declarations, and sorts them according # to its linear extension. This is necessary, because there's some ambiguity in the diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 9a3a0f520e54..21f965cb101b 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -39,7 +39,7 @@ # TODO: consider waiting to group by base name until we actually need to # (after computing type hint signatures, when adding @overload directives) def group_by_base_name(python_funcs: Sequence[PythonSignatureNativeFunctionPair]) -> Mapping[str, List[PythonSignatureGroup]]: - groups = group_overloads(python_funcs, sort=False) + groups = group_overloads(python_funcs) d = collections.defaultdict(list) for g in groups: name = g.signature.name From 218eaf4bbafef23600e6c9e668b7a49633639734 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 11 Dec 2020 13:24:55 -0800 Subject: [PATCH 16/33] pyi codegen refactor - no need to group python signatures by overload name (#49057) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49057 Now that all of the byte-for-byte hacks are removed in the pyi codegen, there's no reason for the codegen to group pyi signature overloads together. I updated the logic in `gen_pyi` that computes signatures (`generate_type_hints()` and _generate_named_tuples()`) to operate per individual `PythonSignatureGroup` Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D25410849 Pulled By: bdhirsh fbshipit-source-id: 8c190035d7bfc06ed192468efbe7d902922ad1fa --- tools/pyi/gen_pyi.py | 113 ++++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 65 deletions(-) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 21f965cb101b..dad150fa0ad5 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -7,7 +7,7 @@ from tools.codegen.model import * from tools.codegen.api.python import * -from typing import Sequence, List, Mapping, Dict +from typing import Sequence, List, Dict from ..autograd.utils import CodeTemplate, write from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads @@ -36,20 +36,10 @@ read gen_pyi for the gory details. """ -# TODO: consider waiting to group by base name until we actually need to -# (after computing type hint signatures, when adding @overload directives) -def group_by_base_name(python_funcs: Sequence[PythonSignatureNativeFunctionPair]) -> Mapping[str, List[PythonSignatureGroup]]: - groups = group_overloads(python_funcs) - d = collections.defaultdict(list) - for g in groups: - name = g.signature.name - d[name].append(g) - return d - def get_py_torch_functions( python_funcs: Sequence[PythonSignatureNativeFunctionPair], method: bool = False, -) -> Mapping[str, Sequence[PythonSignatureGroup]]: +) -> Sequence[PythonSignatureGroup]: """ Get declarations (grouped by name) which should be generated as either functions in the "torch" module or methods on Tensor. @@ -65,7 +55,7 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: Variant.method in python_func.function.variants) should_bind = should_bind_method if method else should_bind_function - return group_by_base_name([f for f in python_funcs if should_bind(f)]) + return group_overloads([f for f in python_funcs if should_bind(f)]) # TODO: Consider defining some aliases for our Union[...] types, to make @@ -176,54 +166,31 @@ def sig_for_ops(opname: str) -> List[str]: else: raise Exception("unknown op", opname) -def generate_named_tuples(funcs: Sequence[PythonSignatureGroup]) -> Dict[str, str]: - namedtuples: Dict[str, str] = {} - for sig_group in funcs: - named_tuple = sig_group.signature.returns.named_tuple_pyi() - if named_tuple is not None: - tuple_name, tuple_def = named_tuple - if tuple_name in namedtuples: - assert namedtuples[tuple_name] == tuple_def - else: - namedtuples[tuple_name] = tuple_def - return namedtuples - -def generate_type_hints(funcs: Sequence[PythonSignatureGroup], is_tensor: bool = False) -> List[str]: - """generate_type_hints(funcs, is_tensor=False) +def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: + type_hints = [] - Generates type hints for the declarations pertaining to the function - :attr:`funcs` are the func from the parsed native_functions.yaml. - The :attr:`is_tensor` flag indicates whether we are parsing - members of the Tensor class (true) or functions in the - `torch` namespace (default, false). - """ + # Some deprecated ops that are on the blocklist are still included in pyi + if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: + return type_hints - type_hints = [] - any_out = any([g for g in funcs if g.outplace is not None]) - - for sig_group in funcs: - # Some deprecated ops that are on the blocklist are still included in pyi - if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: - continue - - # deprecated signatures have separate entries for their functional and out variants - # (as opposed to the native ops, which fuse the two into a single signature). - # generate the functional variant here, if an out variant exists. - if sig_group.signature.deprecated and sig_group.outplace is not None: - type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) - type_hints.append(type_hint) - - # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument - # Generates the out variant if one exists. Otherwise, generate the functional variant - type_hint = sig_group.signature.signature_str_pyi( - skip_outputs=sig_group.outplace is None) + # deprecated signatures have separate entries for their functional and out variants + # (as opposed to the native ops, which fuse the two into a single signature). + # generate the functional variant here, if an out variant exists. + if sig_group.signature.deprecated and sig_group.outplace is not None: + type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) type_hints.append(type_hint) - # Some operators also additionally have a vararg variant of their signature - type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( - skip_outputs=sig_group.outplace is None) - if type_hint_vararg: - type_hints.append(type_hint_vararg) + # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument + # Generates the out variant if one exists. Otherwise, generate the functional variant + type_hint = sig_group.signature.signature_str_pyi( + skip_outputs=sig_group.outplace is None) + type_hints.append(type_hint) + + # Some operators also additionally have a vararg variant of their signature + type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( + skip_outputs=sig_group.outplace is None) + if type_hint_vararg: + type_hints.append(type_hint_vararg) return type_hints @@ -376,11 +343,18 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None: function_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=False, pyi=True) sig_groups = get_py_torch_functions(function_signatures) - for name in sorted(sig_groups.keys()): - unsorted_function_hints[name] += generate_type_hints(sig_groups[name]) - # deprecated signatures are not used when computing named tuples - native_groups = [g for g in sig_groups[name] if not g.signature.deprecated] - namedtuples.update(generate_named_tuples(native_groups)) + for group in sorted(sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_function_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def function_hints = [] for name, hints in sorted(unsorted_function_hints.items()): @@ -490,9 +464,18 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None: tensor_method_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True) tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True) - for name in sorted(tensor_method_sig_groups.keys()): - unsorted_tensor_method_hints[name] += generate_type_hints(tensor_method_sig_groups[name], is_tensor=True) - namedtuples.update(generate_named_tuples(tensor_method_sig_groups[name])) + for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_tensor_method_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def for op in all_ops: name = '__{}__'.format(op) From 15200e385a764721000f1dfadbcaf42c328bafdd Mon Sep 17 00:00:00 2001 From: kiyosora Date: Fri, 11 Dec 2020 13:35:14 -0800 Subject: [PATCH 17/33] Enable torch.where() to support Float16 & BFloat16 type inputs (#49004) Summary: Fixed https://github.com/pytorch/pytorch/issues/49075 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49004 Reviewed By: zou3519 Differential Revision: D25495225 Pulled By: H-Huang fbshipit-source-id: 09418ee5503f65c8862e40119c5802779505a4db --- aten/src/ATen/native/cpu/TensorCompareKernel.cpp | 3 ++- aten/src/ATen/native/cuda/TensorCompare.cu | 2 +- test/test_torch.py | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index b9653c7b25bf..b407eac4d280 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -183,7 +183,8 @@ static void _aminmax_kernel_impl( } static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, iter.dtype(), "where_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, + iter.dtype(), "where_cpu", [&] { if (condition_type == at::ScalarType::Byte) { cpu_kernel( iter, diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index 443bea3f71ac..b10ae52e44fd 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -17,7 +17,7 @@ DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub); namespace { void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] { if (condition_type == at::ScalarType::Byte) { gpu_kernel( iter, diff --git a/test/test_torch.py b/test/test_torch.py index 16e011645899..d2566a90f382 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -254,8 +254,8 @@ def get_tensor(size, dtype, device, contiguous): height = 5 width = 5 for device in torch.testing.get_all_device_types(): - for dt1 in torch.testing.get_all_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): - for dt2 in torch.testing.get_all_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): + for dt1 in torch.testing.get_all_dtypes(): + for dt2 in torch.testing.get_all_dtypes(): for contiguous in [True, False]: x1 = get_tensor((height, width), dt1, device, contiguous) x2 = get_tensor((height, width), dt2, device, contiguous) @@ -6174,7 +6174,7 @@ def _where_valid_scalar_tensor_combination(self, scalar_type, dtype): return False @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_invalid_combination_raises(self, device, dtype): @@ -6186,7 +6186,7 @@ def checkRaises(scalar_type, dtype, condition, x, scalar_1): self._test_where_scalar_template(device, dtype, checkRaises) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_valid_combination(self, device, dtype): From 4bc4ec2686b69166f8784ee6d4ba2d1c9e582968 Mon Sep 17 00:00:00 2001 From: Ilia Cherniavskii Date: Fri, 11 Dec 2020 13:48:32 -0800 Subject: [PATCH 18/33] Reduce kineto logging (#49216) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49216 Libkineto is pretty verbose by default, using libkineto api to reduce amount of logging Test Plan: TORCH_CUDA_ARCH_LIST="6.0;7.0" USE_CUDA=1 USE_MKLDNN=1 BUILD_BINARY=1 python setup.py develop install --cmake python test/test_profiler.py Imported from OSS Reviewed By: ngimel Differential Revision: D25488109 fbshipit-source-id: 61b443bcf928db939f730ba32711385bb2b622d4 --- torch/csrc/autograd/profiler_kineto.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 7c91e76490a1..ac6ef84104f3 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -242,6 +242,7 @@ void prepareProfiler( if (!libkineto::api().isProfilerRegistered()) { libkineto_init(); + libkineto::api().suppressLogMessages(); } if (!libkineto::api().isProfilerInitialized()) { From e3542d2c12d8aaaccf8a53873e480c20dc6b7338 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 11 Dec 2020 13:55:01 -0800 Subject: [PATCH 19/33] [PyTorch] avoid unnecessary call to empty_tensor_restride in empty() (#48211) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48211 Our empty benchmark makes this call unconditionally. If MemoryFormat::Contiguous is indeed a common case (or if workloads are likely to use a consistent-ish memory format), then I'd expect checking first to be a win. ghstack-source-id: 118224990 Test Plan: Profiled empty benchmark with perf, saw time spent in empty_tensor_restride go down. Ran framework overhead benchmarks. ~7% win on empty(), 0.5-1.5% regression on InPlace, ~2% win on OutOfPlace. Seems like both the In/Out of place ones are likely to be noise because they don't exercise empty? Reviewed By: bhosmer Differential Revision: D24914706 fbshipit-source-id: 916771b335143f9b4ec9fae0d8118222ab6e8659 --- aten/src/ATen/Utils.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index a2e5a82c5d06..26fc7dabfd73 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -57,8 +57,12 @@ Tensor empty_cpu( tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + if (memory_format_opt.has_value()) { + // Restriding a just-created empty contiguous tensor does nothing. + if (*memory_format_opt != MemoryFormat::Contiguous) { + tensor.unsafeGetTensorImpl()->empty_tensor_restride(*memory_format_opt); + } + } return tensor; } From 6c1b405a3bc5392081569f1530d74e8459e0e211 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 11 Dec 2020 14:13:05 -0800 Subject: [PATCH 20/33] Updated derivative rules for complex QR decomposition (#48489) Summary: Updated `qr_backward` to work correctly for complex-valued inputs. Added `torch.qr` to list of complex tests. The previous implementation for real-valued differentiation used equation 42 from https://arxiv.org/abs/1001.1654 The current implementation is a bit simpler but the result for the real-valued input case is the same and all tests still pass. Derivation of complex-valued QR differentiation https://giggleliu.github.io/2019/04/02/einsumbp.html Ref. https://github.com/pytorch/pytorch/issues/33152 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48489 Reviewed By: bdhirsh Differential Revision: D25272344 Pulled By: albanD fbshipit-source-id: b53c1fca1683f4aee5f4d5ce3cab9e559170e7cf --- test/test_autograd.py | 2 +- tools/autograd/gen_variable_type.py | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 79 +++++++++++-------------- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 796860cf639f..0d99169f4d65 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4927,7 +4927,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', 'exp', 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul', - 'addcdiv', 'linalg.tensorinv', 'matrix_exp'] + separate_complex_tests + 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', ] + separate_complex_tests def add_test( name, diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 123b180f1774..a17e222f8cf1 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -78,7 +78,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', '_fft_c2c', '_fft_r2c', } diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 4d71d6759e0c..6da1a7e5e934 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2006,67 +2006,58 @@ Tensor qr_backward(const std::vector &grads, const Te const Tensor& A, const Tensor& Q, const Tensor& R) -> Tensor { - // For square and deep (tall) case we refer - // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear - // Algebra Functions with Application in Optimum Experimental Design - // (Extended Version) The derivative for the QR decomposition is adapted - // from Eq. 42 of the above reference. - - // Compute R (R')^{T} + // For square and deep (tall) case we refer: + // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra. + // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition) + // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks. + // https://arxiv.org/abs/1903.09650 Section 3. QR factorization + // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html + + // Compute R grad_R^H Tensor R_term; if (grad_R.defined()) { - R_term = at::matmul(R, grad_R.transpose(-2, -1)); + R_term = at::matmul(R, grad_R.conj().transpose(-2, -1)); } else { // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // Compute Q^{T} Q' + // Compute grad_Q^H Q Tensor Q_term; if (grad_Q.defined()) { - Q_term = at::matmul(Q.transpose(-2, -1), grad_Q); + Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q); } else { // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // We want to compute: (rhs_solve_1 . R^{-T}) - // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T} + Tensor M = R_term - Q_term; + + // Compute M = (tril(M) + tril(M).conj().transpose(-2, -1)) * 0.5 Identity + Tensor M_tril = at::tril(M); + M = M_tril + M_tril.conj().transpose(-2, -1); + M.diagonal(0, -2, -1).mul_(0.5); + + Tensor rhs_term; + if (grad_Q.defined()) { + rhs_term = grad_Q + at::matmul(Q, M); + } else { + rhs_term = at::matmul(Q, M); + } + + // We want to compute: (rhs_term @ R^{-H}) + // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_solve_1^H)^H // Since R is upper triangular, we can do this using - // triangular_solve(rhs_solve_1^{T}, R)^{T} - auto rhs_solve_1 = - R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1); - rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1); - Tensor solve_soln_1; - std::tie(solve_soln_1, std::ignore) = at::triangular_solve( - rhs_solve_1.transpose(-2, -1), + // triangular_solve(rhs_term^H, R)^H + Tensor grad_A; + std::tie(grad_A, std::ignore) = at::triangular_solve( + rhs_term.conj().transpose(-2, -1), R, /*upper=*/true, /*transpose=*/false, /*unitriangular=*/false); - Tensor grad_A; - if (grad_R.defined()) { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R); - } else { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1)); - } - // Successive computations involve computation of QQ^{T} which is identity when A is square - if (A.size(-1) != A.size(-2)) { - Tensor rhs_solve_2; - // We use the same trick from above for this computation - if (grad_Q.defined()) { - rhs_solve_2 = grad_Q - at::matmul(Q, Q_term); - } else { - rhs_solve_2 = -at::matmul(Q, Q_term); - } - Tensor solve_soln_2; - std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R, - /*upper=*/true, /*transpose=*/false, - /*unitriangular=*/false); - grad_A.add_(solve_soln_2.transpose(-2, -1)); - } - return grad_A; + return grad_A.conj().transpose(-2, -1); }; auto m = self.size(-2); @@ -2087,7 +2078,7 @@ Tensor qr_backward(const std::vector &grads, const Te // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y]. // To obtain grad_X we reuse the gradient formula from the square case. // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U), - // where grad_Q_prime = grad_Q + Y @ grad_V.T + // where grad_Q_prime = grad_Q + Y @ grad_V^H // and grad_Y = Q @ grad_V. // Then concatenate grads to get grad_A = [grad_X | grad_Y]. @@ -2099,8 +2090,8 @@ Tensor qr_backward(const std::vector &grads, const Te grad_V = grad_R.narrow(-1, m, n - m); // reuse grad_R to store grad_U grad_R = grad_R.narrow(-1, 0, m); - // grad_Q_prime starts with the value of Y @ grad_V.T - grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1)); + // grad_Q_prime starts with the value of Y @ grad_V^H + grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1)); } else { // when grad_R is not defined then grad_V and grad_Q_prime // get initialized with zeros From c6147ae4c99b13b1fbc8fb1b36deaca941bfd1c6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 11 Dec 2020 14:16:22 -0800 Subject: [PATCH 21/33] [PyTorch] Fix getCustomClassType() perf (#48981) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48981 1) It was copying the entire hash table every time. 2) We don't need to do a hash lookup at all. ghstack-source-id: 118164406 Reviewed By: dzhulgakov Differential Revision: D25385543 fbshipit-source-id: 6be95c742d6713345c51859ce36a7791a9e2e3f0 --- aten/src/ATen/core/ivalue.cpp | 2 +- aten/src/ATen/core/ivalue.h | 14 +++++++++----- aten/src/ATen/core/ivalue_inl.h | 27 +++++++++++++++------------ aten/src/ATen/core/jit_type.h | 19 ++++++++++++------- 4 files changed, 37 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6b8f4412cbf7..60382e37b6ff 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -22,7 +22,7 @@ namespace ivalue { // This is in ivalue.cpp because we need to access Type::annotation_str, which // is declared in jit_type.h -void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { +void checkCustomClassType(const Type* expected_type, const Type* actual_type) { // NB: doing pointer comparison here // If in the future there ever arises a need to call operator== on custom class // Type's, this needs to be changed! diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 9ea18dc8482d..d2e72933b532 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -949,8 +949,8 @@ TORCH_API ska::flat_hash_map& getCustomClassTypeMap(); template -c10::ClassTypePtr getCustomClassType() { - auto tmap = c10::getCustomClassTypeMap(); +c10::ClassTypePtr getCustomClassTypeImpl() { + auto& tmap = c10::getCustomClassTypeMap(); auto res = tmap.find(std::type_index(typeid(T))); if (res == tmap.end()) { throw c10::Error("Can't find class id in custom class type map", ""); @@ -959,9 +959,13 @@ c10::ClassTypePtr getCustomClassType() { } template -inline bool isCustomClassRegistered() { - auto tmap = c10::getCustomClassTypeMap(); - return tmap.find(std::type_index(typeid(T))) != tmap.end(); +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl(); + return cache; } TORCH_API std::unordered_map>& diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 8858d0047abd..b3b53aed994c 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -172,7 +172,7 @@ inline at::Generator IValue::toGenerator() const& { namespace ivalue { void CAFFE2_API -checkCustomClassType(TypePtr expected_type, TypePtr actual_type); +checkCustomClassType(const Type* expected_type, const Type* actual_type); template using Shared = c10::intrusive_ptr; @@ -820,8 +820,8 @@ c10::intrusive_ptr IValue::toCustomClass() && { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -838,8 +838,8 @@ c10::intrusive_ptr IValue::toCustomClass() const& { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -1149,13 +1149,16 @@ template < typename T, std::enable_if_t::value, int>> IValue::IValue(c10::intrusive_ptr custom_class) { - if (!c10::isCustomClassRegistered>()) { - throw c10::Error( - "Trying to instantiate a class that isn't a registered custom class: " + - std::string(c10::util::get_fully_qualified_type_name()), - ""); - } - auto classType = c10::getCustomClassType>(); + TypePtr classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name()), + ""); + } + }(); auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 1736ea91d71e..7fcd5c2d17e9 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1727,13 +1727,18 @@ namespace detail { template struct getTypePtr_ final { static TypePtr call() { - TORCH_CHECK( - isCustomClassRegistered(), - "Type ", - c10::util::get_fully_qualified_type_name(), - " could not be converted to any of the known types." - ); - auto res = getCustomClassType(); + TypePtr res = []() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." + ); + } + }(); return std::dynamic_pointer_cast(std::move(res)); } }; From df027bfd2c34503006ab985348e7205799c3f0fc Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Fri, 11 Dec 2020 14:51:51 -0800 Subject: [PATCH 22/33] Modify Pipe to return an RRef. (#47829) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47829 As per proposal in https://github.com/pytorch/pytorch/issues/44827, the API needs to return an RRef to support inter-host pipelining. For now, we just return a local RRef and only support pipeline on a single host. But having this change in the API upfront ensures we don't make any BC breaking changes later. ghstack-source-id: 118366784 Test Plan: waitforbuildbot Reviewed By: rohan-varma Differential Revision: D24914022 fbshipit-source-id: e711e7d12efa45645f752f0e5e776a3d845f3ef5 --- test/distributed/_pipeline/sync/conftest.py | 16 +++++ .../_pipeline/sync/skip/test_gpipe.py | 12 ++-- .../_pipeline/sync/skip/test_leak.py | 6 +- test/distributed/_pipeline/sync/test_bugs.py | 13 ++-- .../_pipeline/sync/test_inplace.py | 12 ++-- test/distributed/_pipeline/sync/test_pipe.py | 66 +++++++++---------- .../_pipeline/sync/test_transparency.py | 4 +- test/run_test.py | 22 +++++++ torch/distributed/_pipeline/sync/pipe.py | 11 ++-- 9 files changed, 101 insertions(+), 61 deletions(-) diff --git a/test/distributed/_pipeline/sync/conftest.py b/test/distributed/_pipeline/sync/conftest.py index 315431d0b644..561c41d11350 100644 --- a/test/distributed/_pipeline/sync/conftest.py +++ b/test/distributed/_pipeline/sync/conftest.py @@ -5,7 +5,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import pytest +import tempfile import torch +from torch.distributed import rpc @pytest.fixture(autouse=True) @@ -35,3 +37,17 @@ def cuda_sleep(seconds): def pytest_report_header(): return f"torch: {torch.__version__}" + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(file.name), + ) + ) + yield + rpc.shutdown() diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 96ecd84e0d18..90ecd7613d67 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -17,7 +17,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint): +def test_1to3(balance, checkpoint, setup_rpc): if torch.cuda.device_count() < len(balance): pytest.skip("at least %d cuda devices required" % len(balance)) @@ -61,14 +61,14 @@ def forward(self, input): input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) output = model(input) - loss = output.mean() + loss = output.local_value().mean() loss.backward() - assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) + assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device)) -def test_none_skip(): +def test_none_skip(setup_rpc): @skippable(stash=["none"]) class Stash(nn.Module): def forward(self, input): @@ -102,7 +102,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None): for next_grad_fn, _ in grad_fn.next_functions: assert_grad_fn_is_not_portal(next_grad_fn, visited) - assert_grad_fn_is_not_portal(output.grad_fn) + assert_grad_fn_is_not_portal(output.local_value().grad_fn) - output.sum().backward() + output.local_value().sum().backward() assert input.grad.mean().item() == 1 diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 31c4ea13b9f1..7d03a4e9db49 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -29,7 +29,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint): +def test_delete_portal_tensor(train, checkpoint, setup_rpc): # Without checkpointing: # +- Stash --+ +--- Pop ----+ - - - layers # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function @@ -97,7 +97,7 @@ def forward(self, input): if train: model.train() - output = model(input) + output = model(input).local_value() output.norm().backward() else: model.eval() @@ -106,7 +106,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch): +def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): def deny(*args, **kwargs): raise AssertionError("tried to create Portal without Pipe") diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index 4f5346a837b5..a66b7d006ae1 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -12,7 +12,7 @@ from torch.distributed._pipeline.sync import Pipe -def test_python_autograd_function(): +def test_python_autograd_function(setup_rpc): # A Python autograd function might fail with this error: # # RuntimeError: Returning Variables sharing storage with other Variables @@ -41,10 +41,10 @@ def forward(self, input): x = torch.rand(42) y = model(x) - assert torch.allclose(x, y) + assert torch.allclose(x, y.local_value()) -def test_exception_no_hang(): +def test_exception_no_hang(setup_rpc): # In v0.0.2, once a failed partition receives a normal message # (non-closing) for the next micro-batch, a hang occured. The reason was # that a failed partition didn't call in_queue.task_done() on a normal @@ -69,7 +69,7 @@ def forward(self, x): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep): +def test_tuple_wait(cuda_sleep, setup_rpc): # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # Under this behavior, if checkpointing was disabled, there's a possibility # that gradient accumulations on other tensors are not synchronized @@ -113,7 +113,7 @@ def forward(self, triple): b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) y = model((a, b)) - y.norm().backward() + y.local_value().norm().backward() torch.cuda.synchronize(0) torch.cuda.synchronize(1) @@ -121,7 +121,7 @@ def forward(self, triple): assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) -def test_parallel_randoms(): +def test_parallel_randoms(setup_rpc): class Dropouts(nn.Module): def forward(self, x): for _ in range(100): @@ -133,6 +133,7 @@ def forward(self, x): x = torch.rand(10, 10, requires_grad=True) model = Pipe(model, chunks=10, checkpoint="always") y = model(x) + y = y.local_value() y.norm().backward() assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 17b3dac4eca8..3b842dbfb9ab 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -11,12 +11,12 @@ from torch.distributed._pipeline.sync import Pipe -def test_inplace_on_requires_grad(): +def test_inplace_on_requires_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) model = Pipe(model, checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() message = r"a leaf Variable that requires grad .* used in an in-place operation." with pytest.raises(RuntimeError, match=message): @@ -24,14 +24,14 @@ def test_inplace_on_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(): +def test_inplace_on_not_requires_grad(setup_rpc): # In-place operation on a tensor not requiring grad doesn't cause a # RuntimeError. Currently, we cannot detect this case. model = nn.Sequential(nn.ReLU(inplace=True)) model = Pipe(model, [1], devices=["cpu"], checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() del model message = r"a leaf Variable that requires grad .* used in an in-place operation." @@ -40,7 +40,7 @@ def test_inplace_on_not_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(): +def test_inplace_incorrect_grad(setup_rpc): class M(nn.Module): def forward(self, foo_bar): # 'foo' requires grad but 'bar' does not. In-place operation on @@ -62,7 +62,7 @@ def forward(self, foo_bar): foo = torch.tensor([1.0], requires_grad=True) bar = torch.tensor([1.0]) - output = model((foo, bar)) + output = model((foo, bar)).local_value() del model output.backward() diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index c0992c7bc0ed..8b87fa3d31f6 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -68,7 +68,7 @@ def test_chunks_less_than_1(): with pytest.raises(ValueError): Pipe(model, chunks=-1) -def test_batch_size_indivisible(): +def test_batch_size_indivisible(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -79,7 +79,7 @@ def test_batch_size_indivisible(): assert not record -def test_batch_size_small(): +def test_batch_size_small(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -90,7 +90,7 @@ def test_batch_size_small(): assert not record -def test_checkpoint_mode(): +def test_checkpoint_mode(setup_rpc): def count_grad_fn(grad_fn, name, visited=None): if visited is None: visited = set() @@ -119,9 +119,9 @@ def count_grad_fn(grad_fn, name, visited=None): except_last_output = except_last(input) never_output = never(input) - assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2 - assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1 - assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0 + assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 + assert count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") == 1 + assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 def test_checkpoint_mode_invalid(): @@ -140,7 +140,7 @@ def test_checkpoint_mode_when_chunks_1(): Pipe(model, chunks=1, checkpoint="never") -def test_checkpoint_eval(): +def test_checkpoint_eval(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -157,16 +157,16 @@ def find_grad_fn(grad_fn, name): model.train() train_output = model(input) - assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") model.eval() eval_output = model(input) - assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") -def test_checkpoint_non_float_input(): +def test_checkpoint_non_float_input(setup_rpc): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) @@ -183,7 +183,7 @@ def forward(self, input): output.backward() -def test_no_grad(): +def test_no_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -206,7 +206,7 @@ def hook(module, input, output): assert latent.grad_fn is None -def test_exception(): +def test_exception(setup_rpc): class ExpectedException(Exception): pass @@ -221,7 +221,7 @@ def forward(self, *_): model(torch.rand(1)) -def test_exception_early_stop_asap(): +def test_exception_early_stop_asap(setup_rpc): """Even the first partitions have finished to process, the partition before the failed partition should be killed as soon as possible. """ @@ -258,7 +258,7 @@ def forward(self, x): assert counter == 2 -def test_input_pair(): +def test_input_pair(setup_rpc): class Two(nn.Module): def __init__(self): super().__init__() @@ -275,7 +275,7 @@ def forward(self, a_and_b): a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) - a_out, b_out = model((a, b)) + a_out, b_out = model((a, b)).local_value() loss = (a_out + b_out).mean() loss.backward() @@ -283,7 +283,7 @@ def forward(self, a_and_b): assert b.grad is not None -def test_input_singleton(): +def test_input_singleton(setup_rpc): class One(nn.Module): def __init__(self): super().__init__() @@ -298,7 +298,7 @@ def forward(self, only_a): a = torch.rand(10, 1, requires_grad=True) - (a_out,) = model((a,)) + (a_out,) = model((a,)).local_value() loss = a_out.mean() loss.backward() @@ -306,7 +306,7 @@ def forward(self, only_a): assert a.grad is not None -def test_input_varargs(): +def test_input_varargs(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model) @@ -318,7 +318,7 @@ def test_input_varargs(): model(a, b) -def test_non_tensor(): +def test_non_tensor(setup_rpc): class NonTensor(nn.Module): def forward(self, _): return "hello" @@ -336,7 +336,7 @@ def forward(self, _): model("hello") -def test_non_tensor_tuple(): +def test_non_tensor_tuple(setup_rpc): class NonTensorTuple(nn.Module): def forward(self, x): return (x, "hello") @@ -355,7 +355,7 @@ def forward(self, x): @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint): +def test_deferred_batch_norm(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -363,7 +363,7 @@ def test_deferred_batch_norm(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) @@ -371,7 +371,7 @@ def test_deferred_batch_norm(checkpoint): @pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint): +def test_deferred_batch_norm_params(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -379,7 +379,7 @@ def test_deferred_batch_norm_params(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert pipe[0].weight.grad is not None @@ -455,13 +455,13 @@ def test_deny_moving(): model.to(dtype=torch.float) -def test_empty_module(): +def test_empty_module(setup_rpc): # Empty sequential module is not illegal. model = nn.Sequential() model = Pipe(model) - assert model(torch.tensor(42)) == torch.tensor(42) - assert model((torch.tensor(42),)) == (torch.tensor(42),) + assert model(torch.tensor(42)).local_value() == torch.tensor(42) + assert model((torch.tensor(42),)).local_value() == (torch.tensor(42),) # But only tensor or tensors is legal in Pipe. with pytest.raises(TypeError): @@ -518,7 +518,7 @@ def __init__(self, param1, param2): @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") -def test_verify_nested_modules(): +def test_verify_nested_modules(setup_rpc): model = nn.Sequential( nn.Sequential( nn.Linear(32, 16).cuda(0), @@ -532,8 +532,8 @@ def test_verify_nested_modules(): pipe = Pipe(model) out = pipe(torch.rand(10, 32).cuda(0)) - assert out.device == torch.device("cuda:1") - assert out.size() == torch.Size([10, 2]) + assert out.local_value().device == torch.device("cuda:1") + assert out.local_value().size() == torch.Size([10, 2]) def test_verify_module_duplicate_parameters_on_same_device(): class Surrogate(nn.Module): @@ -547,7 +547,7 @@ def __init__(self, module): Pipe(model) -def test_forward_lockstep(): +def test_forward_lockstep(setup_rpc): timeline = [] class DelayedLog(nn.Module): diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 3d2c77e8fef4..56ad86de081b 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -10,7 +10,7 @@ from torch.distributed._pipeline.sync import Pipe -def test_simple_linears(): +def test_simple_linears(setup_rpc): def sum_grad(parameters): return sum([p.grad.sum() for p in parameters if p.grad is not None]) @@ -33,7 +33,7 @@ def zero_grad(parameters): # With Pipe model = Pipe(model, chunks=4) - outputs = model(inputs) + outputs = model(inputs).local_value() loss = outputs.mean() loss.backward() diff --git a/test/run_test.py b/test/run_test.py index 3687459a4a70..54cc33ebc484 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -162,6 +162,28 @@ 'distributed/rpc/test_process_group_agent', 'distributed/rpc/test_tensorpipe_agent', 'distributed/test_distributed_fork', + 'distributed/_pipeline/sync/skip/test_api', + 'distributed/_pipeline/sync/skip/test_gpipe', + 'distributed/_pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/_pipeline/sync/skip/test_leak', + 'distributed/_pipeline/sync/skip/test_portal', + 'distributed/_pipeline/sync/skip/test_stash_pop', + 'distributed/_pipeline/sync/skip/test_tracker', + 'distributed/_pipeline/sync/skip/test_verify_skippables', + 'distributed/_pipeline/sync/test_balance', + 'distributed/_pipeline/sync/test_bugs', + 'distributed/_pipeline/sync/test_checkpoint', + 'distributed/_pipeline/sync/test_copy', + 'distributed/_pipeline/sync/test_deferred_batch_norm', + 'distributed/_pipeline/sync/test_dependency', + 'distributed/_pipeline/sync/test_inplace', + 'distributed/_pipeline/sync/test_microbatch', + 'distributed/_pipeline/sync/test_phony', + 'distributed/_pipeline/sync/test_pipe', + 'distributed/_pipeline/sync/test_pipeline', + 'distributed/_pipeline/sync/test_stream', + 'distributed/_pipeline/sync/test_transparency', + 'distributed/_pipeline/sync/test_worker', ] ROCM_BLOCKLIST = [ diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index 92a3c301cc39..a097e8aa1a9e 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -10,6 +10,7 @@ import torch from torch import Tensor, nn +from torch.distributed.rpc import RRef import torch.autograd import torch.cuda @@ -305,7 +306,7 @@ def _ensure_copy_streams(self) -> List[List[AbstractStream]]: return self._copy_streams - def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ignore """:class:`Pipe` is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module. But there's type restriction. Input and output have to be a @@ -313,10 +314,10 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore applied at partition boundaries too. Args: - input (torch.Tensor or tensors): input mini-batch + input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch Returns: - tensor or tensors: output mini-batch + :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch Raises: TypeError: input is not a tensor or tensors. @@ -326,7 +327,7 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore if not self.devices: # Empty sequential module is not illegal. - return input + return RRef(input) # Divide a mini-batch into micro-batches. batches = microbatch.scatter(input, self.chunks) @@ -336,4 +337,4 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore # Merge the micro-batches into one mini-batch. output = microbatch.gather(batches) - return output + return RRef(output) From 2f359e7d55f8de14fcd74231fc0f256d9fd8c607 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Fri, 11 Dec 2020 14:57:16 -0800 Subject: [PATCH 23/33] Add tensorpipe agent tests to multigpu tests. (#49210) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49210 The RPC tests use multiple gpus in some cases (ex: DDP + RPC and Pipe + DDP). We should enable multigpu tests for this purpose. ghstack-source-id: 118366595 Test Plan: waitforbuildbot Reviewed By: rohan-varma Differential Revision: D25485506 fbshipit-source-id: eabbf442471ebc700b5986bc751879b9cf72b752 --- .jenkins/pytorch/multigpu-test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index 9a2c486610c4..fdf3c03e7f67 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -21,4 +21,5 @@ time python test/run_test.py --verbose -i distributed/test_jit_c10d time python test/run_test.py --verbose -i distributed/test_distributed_fork time python test/run_test.py --verbose -i distributed/test_c10d time python test/run_test.py --verbose -i distributed/test_c10d_spawn +time python test/run_test.py --verbose -i distributed/rpc/test_tensorpipe_agent assert_git_not_dirty From 53aa9b8c829edfc4194259f2c14b194171074cf9 Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 11 Dec 2020 15:43:04 -0800 Subject: [PATCH 24/33] [FX] Move none assignments to same line (#49209) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49209 Test Plan: Imported from OSS Reviewed By: Chillee Differential Revision: D25484975 Pulled By: jamesr66a fbshipit-source-id: 44207be878f95ec9420e87af79833191d5cc0c7e --- torch/fx/graph.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index ca4b8d64bb0e..f8bc96b73c40 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -617,10 +617,15 @@ def delete_unused_values(user : Node): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ + if user.op == 'output': + body.append('\n') + return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): to_delete_str = ' = '.join([n.name for n in nodes_to_delete] + ['None']) - body.append(f'{to_delete_str}\n') + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') def emit_node(node : Node): if node.op == 'placeholder': @@ -630,20 +635,20 @@ def emit_node(node : Node): free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') raw_name = node.target.replace('*', '') if raw_name != node.name: - body.append(f'{node.name} = {raw_name}\n') + body.append(f'{node.name} = {raw_name}') return elif node.op == 'call_method': assert isinstance(node.target, str) body.append( f'{node.name} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})\n') + f'({_format_args(node.args[1:], node.kwargs)})') return elif node.op == 'call_function': assert callable(node.target) # pretty print operators if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n') + body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') return qualified_name = get_qualified_name(node.target) register_modules_used(qualified_name) @@ -652,26 +657,28 @@ def emit_node(node : Node): isinstance(node.args[1], str) and \ node.args[1].isidentifier(): # pretty print attribute access - body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n') + body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}') return - body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n') + body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})') return elif node.op == 'call_module': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n') + body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') return elif node.op == 'get_attr': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}\n') + body.append(f'{node.name} = {_format_target(root_module, node.target)}') return elif node.op == 'output': if node.type is not None: maybe_return_annotation = f" -> {type_repr(node.type)}" - body.append(f'return {repr(node.args[0])}\n') + body.append(f'return {repr(node.args[0])}') return raise NotImplementedError(f'node: {node.op} {node.target}') for node in self.nodes: + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one emit_node(node) delete_unused_values(node) From bfce69d6200ecf1261bf8d45657c802c56317365 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Fri, 11 Dec 2020 15:45:26 -0800 Subject: [PATCH 25/33] inline `has` function for DispatchKeySet (#49191) Summary: inlines `has` function for DispatchKeySet, that is frequently used in TensorImpl in calls such as `is_sparse`, `is_cuda` etc. This increases `empty` instruction count (1853228 -> 1937428) without appreciable effect on runtime, and noticeably reduces instruction counts for `copy_` and friends that have to rely on `is_sparse`, `is_cuda` and the like a lot to decide which path to take (3269114 -> 2634114). Pull Request resolved: https://github.com/pytorch/pytorch/pull/49191 Reviewed By: H-Huang Differential Revision: D25483011 Pulled By: ngimel fbshipit-source-id: 2f3ab83e2c836a726b9284ffc50d6ecf3701aada --- c10/core/DispatchKeySet.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 1e9d85211f6d..486272ece92e 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -61,8 +61,8 @@ class DispatchKeySet final { } } // Test if a DispatchKey is in the set - bool has(DispatchKey t) const { - TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); + bool inline has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); return static_cast(repr_ & DispatchKeySet(t).repr_); } // Test if DispatchKeySet is a superset of ks. From 5716b7db72e8f66b7b2ab312cb3623b87aeb89d8 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Fri, 11 Dec 2020 15:46:00 -0800 Subject: [PATCH 26/33] Enabled Scalar lists (#48222) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48222 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25074765 Pulled By: izdeby fbshipit-source-id: 96ebe3c9907178c9338c03fb7993b2ecb26db8f4 --- .../impl/make_boxed_from_unboxed_functor.h | 8 ---- tools/codegen/api/python.py | 5 ++- tools/jit/gen_unboxing_wrappers.py | 2 + tools/pyi/gen_pyi.py | 1 - torch/csrc/utils/python_arg_parser.cpp | 44 ++++++++++++++----- torch/csrc/utils/python_arg_parser.h | 17 ++++++- 6 files changed, 56 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 3dfb4ee4f04b..3d040387d3bb 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -119,14 +119,6 @@ namespace impl { "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); }; - template - struct assert_is_valid_input_type, AllowDeprecatedTypes> - : assert_is_valid_input_type { - static_assert(!std::is_same::value, - "You tried to register a kernel with an unsupported input type: std::vector. Please use List, List or Tensor instead."); - // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector. Please use List instead."); - }; - template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index 10483e2e3d76..c78fe23150e8 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -620,6 +620,8 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef' elif str(t.elem) == 'Tensor': return f'TensorList[{size}]' if size is not None else 'TensorList' + elif str(t.elem) == 'Scalar': + return f'ScalarList[{size}]' if size is not None else 'ScalarList' elif str(t.elem) == 'Tensor?': if simple_type: return 'TensorList' @@ -1063,7 +1065,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str: return 'intlist' elif str(t) == 'float[]': return 'doublelist' - + elif str(t) == 'Scalar[]': + return 'scalarlist' raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser') # Return RHS expression for python argument using PythonArgParser output. diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py index f2896fac7f22..267b5a3b221a 100644 --- a/tools/jit/gen_unboxing_wrappers.py +++ b/tools/jit/gen_unboxing_wrappers.py @@ -49,6 +49,7 @@ 'std::string': 'str', 'std::string?': 'str?', 'Scalar': 'Scalar', + 'ScalarList': 'Scalar[]', 'MemoryFormat': 'MemoryFormat', 'MemoryFormat?': 'MemoryFormat?', 'QScheme': 'QScheme', @@ -131,6 +132,7 @@ def jit_type_of(arg): 'Tensor?': 'toOptionalTensor({})', 'Tensor?[]': 'toListOfOptionalTensor({})', 'TensorList': '{}.toTensorVector()', + 'ScalarList': '{}.toScalarVector()', 'bool': '{}.toBool()', 'bool?': '{}.toOptional()', 'double': '{}.toDouble()', diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index dad150fa0ad5..d2073bec9a27 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -122,7 +122,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: 'floor_divide', 'floor_divide_', 'floor_divide_out', ] - binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv', 'matmul', 'floordiv', 'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 950e7d9fb82d..c7fdf844945e 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -39,6 +39,7 @@ static std::unordered_map type_map = { {"std::string", ParameterType::STRING}, {"Dimname", ParameterType::DIMNAME}, {"DimnameList", ParameterType::DIMNAME_LIST}, + {"ScalarList", ParameterType::SCALAR_LIST}, }; // Default arg name translations for compatibility with NumPy. @@ -348,13 +349,28 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove return false; } -bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { +bool is_scalar_list(PyObject* obj) { auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; } auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); for (size_t idx = 0; idx < size; idx++) { + PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + if (!THPUtils_checkScalar(iobj)) { + return false; + } + } + return true; +} + +bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { + auto tuple = six::isTuple(obj); + if (!(tuple || PyList_Check(obj))) { + return false; + } + auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); + for (long idx = 0; idx < size; idx++) { PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) { if (throw_error) { @@ -453,6 +469,9 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded return THPStream_Check(obj); case ParameterType::STRING: return THPUtils_checkString(obj); default: throw std::runtime_error("unknown parameter type"); + case ParameterType::SCALAR_LIST: { + return is_scalar_list(obj); + } } } @@ -478,6 +497,7 @@ std::string FunctionParameter::type_name() const { case ParameterType::STRING: return "str"; case ParameterType::DIMNAME: return "name"; case ParameterType::DIMNAME_LIST: return "tuple of names"; + case ParameterType::SCALAR_LIST: return "tuple of Scalars"; default: throw std::runtime_error("unknown parameter type"); } } @@ -1055,24 +1075,28 @@ at::Scalar PythonArgs::scalar_slow(int i) { signature.params[i].name, idx, var, jit::NumberType::get()); } + return scalar_slow(args[i]); +} + +at::Scalar PythonArgs::scalar_slow(PyObject* arg) { // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently // handle most NumPy scalar types except np.float64. - if (THPVariable_Check(args[i])) { - return ((THPVariable*)args[i])->cdata.item(); + if (THPVariable_Check(arg)) { + return ((THPVariable*)arg)->cdata.item(); } - if (THPUtils_checkLong(args[i])) { - return at::Scalar(static_cast(THPUtils_unpackLong(args[i]))); + if (THPUtils_checkLong(arg)) { + return at::Scalar(static_cast(THPUtils_unpackLong(arg))); } - if (PyBool_Check(args[i])) { - return at::Scalar(THPUtils_unpackBool(args[i])); + if (PyBool_Check(arg)) { + return at::Scalar(THPUtils_unpackBool(arg)); } - if (PyComplex_Check(args[i])) { - return at::Scalar(THPUtils_unpackComplexDouble(args[i])); + if (PyComplex_Check(arg)) { + return at::Scalar(THPUtils_unpackComplexDouble(arg)); } - return at::Scalar(THPUtils_unpackDouble(args[i])); + return at::Scalar(THPUtils_unpackDouble(arg)); } } // namespace torch diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index b0b81a9517da..ccf3ba6b42c4 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -80,7 +80,7 @@ namespace torch { enum class ParameterType { TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST + DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST }; struct FunctionParameter; @@ -158,6 +158,7 @@ struct PythonArgs { inline c10::optional optionalTensor(int i); inline at::Scalar scalar(int i); inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar); + inline std::vector scalarlist(int i); inline std::vector tensorlist(int i); template inline std::array tensorlist_n(int i); @@ -206,6 +207,7 @@ struct PythonArgs { private: at::Tensor tensor_slow(int i); at::Scalar scalar_slow(int i); + at::Scalar scalar_slow(PyObject* arg); }; struct FunctionParameter { @@ -287,6 +289,19 @@ inline at::Scalar PythonArgs::scalar(int i) { return scalar_slow(i); } +inline std::vector PythonArgs::scalarlist(int i) { + if (!args[i]) return std::vector(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + std::vector res(size); + for (int idx = 0; idx < size; idx++) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + res[idx] = scalar_slow(obj); + } + return res; +} + inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) { if (!args[i]) return default_scalar; return scalar_slow(i); From 6b7864462387e14176431efb525f538c1be4d255 Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Fri, 11 Dec 2020 16:02:04 -0800 Subject: [PATCH 27/33] [te] Add BitCast to the IR (#49184) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49184 Adds BitCasting to NNC. This will enable fast approximation algorithms implemented directly in TensorExpressions Test Plan: buck test mode/no-gpu //caffe2/test/cpp/tensorexpr:tensorexpr Reviewed By: bertmaher Differential Revision: D25466476 fbshipit-source-id: f063ab29ba7bab2dcce463e499f2d4a16bdc1f0e --- test/cpp/tensorexpr/test_llvm.cpp | 83 ++++++++++++++++ test/cpp/tensorexpr/test_type.cpp | 110 +++++++++++++++++++++ torch/csrc/jit/tensorexpr/eval.h | 60 +++++++++++ torch/csrc/jit/tensorexpr/expr.h | 1 + torch/csrc/jit/tensorexpr/ir.h | 29 ++++++ torch/csrc/jit/tensorexpr/ir_mutator.cpp | 9 ++ torch/csrc/jit/tensorexpr/ir_mutator.h | 2 + torch/csrc/jit/tensorexpr/ir_visitor.cpp | 3 + torch/csrc/jit/tensorexpr/ir_visitor.h | 2 + torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 20 ++++ torch/csrc/jit/tensorexpr/loopnest.cpp | 8 ++ 11 files changed, 327 insertions(+) diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 953c184de1fc..c1d3392fff32 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -160,6 +160,63 @@ TEST(LLVM, ByteToDoubleCastTest) { ASSERT_EQ(cg.value(), 2); } +TEST(LLVM, BitCast) { + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(k); + auto b = BitCast::make(kShort, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } +} + TEST(LLVM, LetTest01) { KernelScope kernel_scope; @@ -514,6 +571,32 @@ TEST(LLVM, VectorizerLoadStoreTest) { assertAllEqual(c_vec, 21); } +TEST(LLVM, VectorizeBitCast) { + KernelScope kernel_scope; + Placeholder a(BufHandle("A", {128}, kInt)); + + Tensor* c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) { + return bitcast(a.load(i)); + }); + + Placeholder c_buf(BufHandle(c->buf())); + LoopNest l({c}); + Stmt* s = l.root_stmt(); + l.vectorize(dynamic_cast(s)->front()); + ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(128); + std::vector c_vec(128); + for (auto i = 0; i < 128; ++i) { + a_vec[i] = raw_bitcast(1337.f); + } + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 1337.f); +} + TEST(LLVM, MemcpyTest) { KernelScope kernel_scope; constexpr int N = 32; diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp index 0c771733d935..71ad0f5149ac 100644 --- a/test/cpp/tensorexpr/test_type.cpp +++ b/test/cpp/tensorexpr/test_type.cpp @@ -1,5 +1,6 @@ #include +#include "torch/csrc/jit/tensorexpr/eval.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/tensor.h" @@ -42,6 +43,115 @@ TEST(Type, Test01) { } } +TEST(Type, BitCasting) { + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kInt); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kFloat); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kHalf); + } + { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kShort); + } + + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + using SimpleIRExprEval = ExprEval; + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(*k); + auto b = BitCast::make(kShort, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } + + // This segfaults :( + /*{ + KernelScope kernel_scope; + VarHandle x("x", kDouble); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kLong); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + }*/ +} + TEST(Type, Propagation) { // Same types: { diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 7b8a4c194782..e7fbd376d563 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -124,6 +125,14 @@ inline c10::Half div_value(c10::Half lhs, c10::Half rhs) { return lhs / rhs; } +template +To raw_bitcast(const From& src) { + TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation"); + To storage; + std::memcpy(&storage, &src, sizeof(From)); + return reinterpret_cast(storage); +} + class SimpleIREvaluator : public CodeGen, public IRVisitor { public: template @@ -573,6 +582,57 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor { } } + template + std::vector bitcastValues(const Dtype& src_dtype, const Value& v) { + const std::vector& src_values = v.as_vec(); + std::vector dst_values(src_values.size()); + for (int i = 0; i < src_dtype.lanes(); ++i) { + dst_values[i] = raw_bitcast(src_values[i]); + } + return dst_values; + } + + template + void doBitCastFromSrc( + const Dtype& src_dtype, + const Dtype& dst_dtype, + const Value& v) { + switch (dst_dtype.scalar_type()) { +#define DST_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + this->value_ = Value(bitcastValues(src_dtype, v)); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE); +#undef DST_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const BitCast* v) override { + const Expr* src_value = v->src_value(); + src_value->accept(this); + Dtype dst_dtype = v->dtype(); + Dtype src_dtype = src_value->dtype(); + if (src_dtype.byte_size() != dst_dtype.byte_size()) { + throw malformed_input("lane mismatch in Cast", v); + } + if (src_dtype != dst_dtype) { + switch (src_dtype.scalar_type()) { +#define SRC_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + doBitCastFromSrc(src_dtype, dst_dtype, value_); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE); +#undef SRC_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + } + TORCH_API void visit(const For* v) override { const Expr* var_node = v->var(); v->start()->accept(this); diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 9b8dd23db0b1..cd05333656c0 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -31,6 +31,7 @@ enum IRNodeType { kCompareSelect, kLet, kCast, + kBitCast, kBroadcast, kRamp, kPolynomial, diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 7eeea564a6a7..6fe4bf0e2ebd 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -28,6 +28,7 @@ inline int getPrecedence(IRNodeType ty) { case kPrimitive: return 0; case kCast: + case kBitCast: return 2; case kAdd: case kSub: @@ -81,6 +82,34 @@ ExprHandle cast(const ExprHandle& src_value) { return Cast::make(Dtype(ToDtype(), src_value.dtype().lanes()), src_value); } +// This is a bitwise cast, akin to bitcast in LLVM +class BitCast : public ExprNode { + public: + const Expr* src_value() const { + return src_value_; + } + static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { + return ExprHandle(new BitCast(dtype, src_value.node())); + } + BitCast(Dtype dtype, const Expr* src_value) + : ExprNodeBase(dtype, kBitCast), src_value_(src_value) { + TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size()); + } + + bool isConstant() const override { + return src_value_->isConstant(); + } + + private: + const Expr* src_value_; +}; + +template +ExprHandle bitcast(const ExprHandle& src_value) { + return BitCast::make( + Dtype(ToDtype(), src_value.dtype().lanes()), src_value); +} + // Represent the expression node for binary operators. // A CRTP pattern to share common code among the operators. template diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 5f0889842b1e..ddbe88bb2c8f 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -139,6 +139,15 @@ const Expr* IRMutator::mutate(const Cast* v) { return new Cast(v->dtype(), src_value_new); } +const Expr* IRMutator::mutate(const BitCast* v) { + const Expr* src_value = v->src_value(); + const Expr* src_value_new = src_value->accept_mutator(this); + if (src_value_new == v->src_value()) { + return v; + } + return new BitCast(v->dtype(), src_value_new); +} + const Expr* IRMutator::mutate(const Var* v) { return v; } diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index 0913da0e972d..773920cb52fa 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -75,6 +76,7 @@ class TORCH_API IRMutator { AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE virtual const Expr* mutate(const Cast* v); + virtual const Expr* mutate(const BitCast* v); virtual const Expr* mutate(const Var* v); virtual const Expr* mutate(const Buf* v); virtual const Expr* mutate(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index ae97a6200d8b..772a28c77add 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -79,6 +79,9 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); void IRVisitor::visit(const Cast* v) { v->src_value()->accept(this); } +void IRVisitor::visit(const BitCast* v) { + v->src_value()->accept(this); +} void IRVisitor::visit(const Var* v) {} void IRVisitor::visit(const Ramp* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 3f5f05229c16..8353da680edb 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE) #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -74,6 +75,7 @@ class TORCH_API IRVisitor { #undef IMM_PRINT_VISIT virtual void visit(const Cast* v); + virtual void visit(const BitCast* v); virtual void visit(const Var* v); virtual void visit(const Buf* v); virtual void visit(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index cb14b9ef4c07..d469a39cf69d 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -164,6 +164,7 @@ class LLVMCodeGenImpl : public IRVisitor { #undef IMM_VISIT_DECLARE void visit(const Cast* v) override; + void visit(const BitCast* v) override; void visit(const Var* v) override; void visit(const Ramp* v) override; void visit(const Load* v) override; @@ -888,6 +889,25 @@ void LLVMCodeGenImpl::visit(const Cast* v) { } } +void LLVMCodeGenImpl::visit(const BitCast* v) { + v->src_value()->accept(this); + + llvm::Type* dstType = dtypeToLLVM(v->dtype()); + if (v->dtype().lanes() > 1) { + dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes())); + } + llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); + + if (srcType == dstType) { + // do nothing. + return; + } + + TORCH_CHECK(llvm::CastInst::isBitCastable( + srcType->getScalarType(), dstType->getScalarType())); + value_ = irb_.CreateBitOrPointerCast(value_, dstType); +} + void LLVMCodeGenImpl::visit(const Var* v) { if (varToArg_.count(v)) { auto idx = varToArg_.at(v); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 0bff2dbf75c7..1598a92ac68c 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -154,6 +154,14 @@ class Vectorizer : public IRMutator { }); } + const Expr* mutate(const BitCast* v) override { + std::vector inputs = {v->src_value()}; + return try_vectorize(v, inputs, [&]() { + return BitCast::make( + Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); + }); + } + const Expr* mutate(const Cast* v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { From 21c38e17997171415a44c6ba578f621037d8ef30 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Fri, 11 Dec 2020 17:20:51 -0800 Subject: [PATCH 28/33] Additional validation for DistributedSampler. (#48865) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48865 If DistributedSampler was provided an invalid rank (ex: https://discuss.pytorch.org/t/distributed-datasets-on-multi-machines/105113), it failed with a cryptic assertion failure. To fix this issue, I've added an additional check to DistributedSampler to validate we provide a valid rank. ghstack-source-id: 117906769 Test Plan: 1) waitforbuildbot 2) Unit test added. Reviewed By: malfet Differential Revision: D25344945 fbshipit-source-id: 7685e00c8b2c200efbd2949fb32ee32ea7232a08 --- test/test_dataloader.py | 9 +++++++++ torch/utils/data/distributed.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index a1afc216d42a..047297c438b7 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1454,6 +1454,15 @@ def test_random_sampler_len_with_replacement(self): self.assertEqual(int(math.ceil(float(num_samples) / batch_size)), count_num_samples_in_data_loader) + def test_distributed_sampler_invalid_rank(self): + from torch.utils.data.distributed import DistributedSampler + dataset = torch.IntTensor(range(10)) + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, 3) + + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, -1) + def test_duplicating_data_with_drop_last(self): from torch.utils.data.distributed import DistributedSampler diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index cb67625df518..e048b54a462c 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -67,6 +67,10 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank From 29f0fa36b1f78117a60378a8e5df5c284e1e346d Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Fri, 11 Dec 2020 17:43:59 -0800 Subject: [PATCH 29/33] [Gradient Compression] Minor update of the comments on PowerSGD. (#49246) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49246 Previously the comment on matrix_approximation_rank was in PowerSGD_hook function. Now move it into PowerSGDState, because the function arg is already moved to this state as an attribute. Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202 ghstack-source-id: 118414247 Test Plan: N/A Reviewed By: rohan-varma Differential Revision: D25501091 fbshipit-source-id: 701e3109a9a3f2a5f9d18d5bf6d0a266518ee8ea --- torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index 99ba72cc5868..bbcef98d4214 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -47,6 +47,8 @@ def __init__( random_seed=0, ): self.process_group = process_group + # The low rank for matrix approximation. + # Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. self.matrix_approximation_rank = matrix_approximation_rank # Error feedback is usually crucial for both for convergence and generalization, # because PowerSGD is a biased compressor, @@ -97,8 +99,6 @@ def powerSGD_hook( bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode at this time, only exactly one tensor is stored in this bucket. - matrix_approximation_rank (int): The low rank for matrix approximation. - Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. Returns: Future handler of the communication, which updates the gradients in place. From 76d41c801eca14dbe9ba12399d27ef78ed0b642f Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 11 Dec 2020 17:53:04 -0800 Subject: [PATCH 30/33] [JIT] Fix toIValue handling of AttributeError when casting ClassType (#49188) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49188 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D25476573 Pulled By: jamesr66a fbshipit-source-id: cec296fae71cc0cdf36bde60417d7d3b1aa84198 --- test/jit/test_class_type.py | 20 ++++++++++++++++++++ test/jit/test_torchbind.py | 4 ++++ torch/csrc/jit/python/pybind_utils.h | 9 +++++++++ 3 files changed, 33 insertions(+) diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index b4075dba14c8..a80670f0d22b 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -959,6 +959,26 @@ def forward(self, x): # Make sure class constant is accessible from module self.assertEqual(m.w, m_loaded.w) + def test_py_class_to_ivalue_missing_attribute(self): + global Foo # see [local resolution in python] + + class Foo(object): + i : int + f : float + + def __init__(self, i : int, f : float): + self.i = i + self.f = f + + @torch.jit.script + def test_fn(x : Foo) -> float: + return x.i + x.f + + test_fn(Foo(3, 4.0)) + + with self.assertRaisesRegex(RuntimeError, 'missing attribute i'): + test_fn(torch.rand(3, 4)) + def test_unused_method(self): """ Test unused methods on scripted classes. diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index af7897e159b3..31eec81d480a 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -240,6 +240,10 @@ def forward(self): traced = torch.jit.trace(TryTracing(), ()) self.assertEqual(torch.zeros(4, 4), traced()) + def test_torchbind_pass_wrong_type(self): + with self.assertRaisesRegex(RuntimeError, 'missing attribute capsule'): + torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4)) + def test_torchbind_tracing_nested(self): class TryTracingNest(torch.nn.Module): def __init__(self): diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index dc3b3b13adef..34ca7585be67 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -713,6 +713,15 @@ inline IValue toIValue( const auto& attrType = classType->getAttribute(slot); const auto& attrName = classType->getAttributeName(slot); + if (!py::hasattr(obj, attrName.c_str())) { + throw py::cast_error(c10::str( + "Tried to cast object to type ", + type->repr_str(), + " but object", + " was missing attribute ", + attrName)); + } + const auto& contained = py::getattr(obj, attrName.c_str()); userObj->setSlot(slot, toIValue(contained, attrType)); } From 635f1cd1a57d10d381ff043689281ea578445744 Mon Sep 17 00:00:00 2001 From: Venkata Chintapalli Date: Fri, 11 Dec 2020 17:56:45 -0800 Subject: [PATCH 31/33] Enable LayerNorm test cases Summary: Remove Skip from test defs. Test Plan: https://our.intern.facebook.com/intern/testinfra/testrun/1407375060598951 Reviewed By: hyuen Differential Revision: D25513174 fbshipit-source-id: 0ddfd1713cf7b9daf25f6e62df92d682cade350f --- caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py index 36d6ba73e0c3..f992c6f9e1fc 100644 --- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py @@ -27,7 +27,7 @@ class LayerNorm(serial.SerializedTestCase): epsilon=st.floats(min_value=1e-4, max_value=1e-3), elementwise_affine=st.booleans()) @settings(deadline=datetime.timedelta(seconds=10)) - def Skip_test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): + def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace workspace.ResetWorkspace() @@ -142,7 +142,7 @@ def _layernorm_transform(self, X): elementwise_affine=st.booleans()) @settings(deadline=datetime.timedelta(seconds=10)) # re-enable when T74553975 gets fixed - def Skip_test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): + def test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace From 8d58362f59edb149fcee691ffca03ecdd94066fe Mon Sep 17 00:00:00 2001 From: Martin Yuan Date: Fri, 11 Dec 2020 18:49:27 -0800 Subject: [PATCH 32/33] [PyTorch] Remove native::zeros reference in TensorIndexing (#49117) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49117 Try to resolve part of the github issue of https://github.com/pytorch/pytorch/issues/48684 . It essentially calls the same functionality inside at::native::zeros(). After this diff, all references to aten::native symbols are removed. ghstack-source-id: 118261305 Test Plan: CI Reviewed By: dhruvbird Differential Revision: D25444940 fbshipit-source-id: 7f782680daa3aedd1b7301cb08576da2ec70c188 --- aten/src/ATen/TensorIndexing.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index a4c0a0b31c34..162efd1c6c8a 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -227,7 +227,7 @@ static inline Tensor applySelect( static inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty. if (value) { - return at::native::zeros({1}, {}, self.options().dtype(kLong)); + return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.); } else { return at::empty({0}, {}, self.options().dtype(kLong)); } From 693e9086561e6badadcb0aeda712a7300c876983 Mon Sep 17 00:00:00 2001 From: Chunli Fu Date: Fri, 11 Dec 2020 19:38:47 -0800 Subject: [PATCH 33/33] [shape inference] fix ConstantFill Test Plan: unit test Reviewed By: yinghai Differential Revision: D25326529 fbshipit-source-id: 1322635567f6661637cde90cadaac0197975e133 --- caffe2/opt/bound_shape_inference_test.cc | 24 ++++++++++++++++++++++++ caffe2/opt/bound_shape_inferencer.cc | 6 ++++++ 2 files changed, 30 insertions(+) diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index 95302ca5ccc4..f9c9b6acf034 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -270,6 +270,30 @@ TEST(BoundShapeInference, LengthsRangeFill) { TensorProto_DataType_INT32); } + +TEST(BoundShapeInference, ConstantFill) { + NetDef net; + net.add_op()->CopyFrom( + CreateOperatorDef("ConstantFill", "", {"X"}, {"Y"}, {})); + ShapeInfoMap shape_map; + BoundShapeSpec spec(20, 1000); + BoundShapeInferencer eng(spec); + shape_map.emplace( + "X", + makeTensorInfo( + {TensorBoundShape_DimType_BATCH, + TensorBoundShape_DimType_CONSTANT}, + {20, 1024})); + eng.InferBoundShapeAndType(net, shape_map, nullptr); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, + "Y", + {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT}, + {20, 1024}, + TensorProto_DataType_FLOAT); +} + // https://github.com/pytorch/pytorch/issues/40861 TEST(BoundShapeInference, DISABLED_ON_WINDOWS(Reshape)) { NetDef net; diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index c513c1a37b01..8ef5de06b02e 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -322,6 +322,12 @@ void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) { if (it != shape_info_.end()) { it->second.setDimType(std::vector( it->second.shape.dims_size(), TensorBoundShape_DimType_CONSTANT)); + if (op.type() == "ConstantFill" && op.input_size() >= 1) { + auto it_input = shape_info_.find(op.input(0)); + if (it_input != shape_info_.end()) { + it->second.setDimType(it_input->second.getDimType()); + } + } } }