From 64681d6beccb552cbcf62a12ea2b877ab2b69069 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 6 Oct 2020 12:34:14 -0700 Subject: [PATCH 01/69] Add all remaining method declarations from torch.distributed Python API to C++ (#45768) Summary: Also ran formatter on previous sections Pull Request resolved: https://github.com/pytorch/pytorch/pull/45768 Reviewed By: wanchaol Differential Revision: D24129467 Pulled By: gmagogsfm fbshipit-source-id: aa8a5c45c3609d5b96e5f585b699d9e3e71394c8 --- torch/csrc/distributed/c10d/c10d_frontend.h | 276 +++++++++++++++----- 1 file changed, 208 insertions(+), 68 deletions(-) diff --git a/torch/csrc/distributed/c10d/c10d_frontend.h b/torch/csrc/distributed/c10d/c10d_frontend.h index 9ff4b69999c7..816c8d9fe473 100644 --- a/torch/csrc/distributed/c10d/c10d_frontend.h +++ b/torch/csrc/distributed/c10d/c10d_frontend.h @@ -1,86 +1,226 @@ #pragma once -#include -#include #include #include +#include +#include +#include +#include +#include #include #include -#include -#include namespace c10d { class Backend { - public: - // Maps to Backend.__new__ in Python. - static std::string get(std::string); + public: + // Maps to Backend.__new__ in Python. + static std::string get(std::string); - // TODO: How to support registering third_party backend? - static void registerBackend(); + // TODO: How to support registering third_party backend? + static void registerBackend(); - private: - // TODO: Should this be an enum list instead since this set doesn't - // change at all. - std::unordered_set registered_backends_; + private: + // TODO: Should this be an enum list instead since this set doesn't + // change at all. + std::unordered_set registered_backends_; }; -class DistributedC10d{ - public: - void initProcessGroup( - const std::string& backend, - const std::string& init_method, - const std::chrono::milliseconds& timeout, - int64_t world_size, - int64_t rank, - std::shared_ptr store, - const std::string& group_name); - - void destroyProcessGroup(std::shared_ptr group); - int64_t getRank(std::shared_ptr group); - int64_t getWorldSize(std::shared_ptr group); - - ProcessGroup::Work isend(at::Tensor tensor, int64_t dst, std::shared_ptr group, c10::optional tag); - ProcessGroup::Work irecv(at::Tensor tensor, int64_t src, std::shared_ptr group, c10::optional tag); - - private: - DistributedC10d(){}; - - bool rankNotInGroup(std::shared_ptr group) const; - int64_t getGroupRank( - std::shared_ptr group, - const int64_t rank) const; - int64_t getGlobalRank( - std::shared_ptr group, - const int64_t global_rank) const; - void checkDefaultPg() const; - int64_t getGroupSize(std::shared_ptr group) const; - int64_t getBackend(std::shared_ptr group); - - std::string backend_; - // TODO: Ask Alex what kind of equality we need. It determine whether we - // need to use ProcessGroup or ProcesGroup* as key. - std::unordered_map< - std::shared_ptr, - std::pair, std::shared_ptr>> - pg_map_; - - // Note, this is different mapping relationship than original Python - // implementation. - std::unordered_map, std::string> pg_names_; - - // Value is global_rank:group_rank mapping. 
- std::unordered_map, std::vector> - pg_group_ranks_; - - std::shared_ptr default_pg_; - - // Default value should be "env://" - std::string default_pg_init_method_; - - int64_t group_count_; +class DistributedC10d { + public: + void initProcessGroup( + const std::string& backend, + const std::string& init_method, + const std::chrono::milliseconds& timeout, + int64_t world_size, + int64_t rank, + std::shared_ptr store, + const std::string& group_name); + + void destroyProcessGroup(std::shared_ptr group); + int64_t getRank(std::shared_ptr group); + int64_t getWorldSize(std::shared_ptr group); + + ProcessGroup::Work isend( + at::Tensor tensor, + int64_t dst, + std::shared_ptr group, + c10::optional tag); + + ProcessGroup::Work irecv( + at::Tensor tensor, + int64_t src, + std::shared_ptr group, + c10::optional tag); + + ProcessGroup::Work send( + at::Tensor tensor, + int64_t dst, + std::shared_ptr group, + c10::optional tag); + + ProcessGroup::Work recv( + at::Tensor tensor, + int64_t src, + std::shared_ptr group, + c10::optional tag); + + c10::optional broadcastMultiGPU( + std::vector tensor_list, + int64_t src, + std::shared_ptr group, + bool async_op, + int64_t src_tensor); + + c10::optional broadcast( + at::Tensor tensor, + int64_t src, + std::shared_ptr group, + bool async_op); + + c10::optional allReduceMultiGPU( + std::vector& tensor_list, + ReduceOp op, + std::shared_ptr group, + bool async_op); + + c10::optional allReduce( + at::Tensor tensor, + ReduceOp op, + std::shred_ptr group, + bool async_op); + + c10::optional allReduceCoalesced( + at::Tensor tensor, + ReduceOp op, + std::shred_ptr group, + bool async_op); + + c10::optional reduceMultiGPU( + std::vector& tensor_list, + int64_t dst, + ReduceOp op, + std::shared_ptr group, + bool async_op, + int64_t dst_tensor); + + c10::optional reduce( + at::Tensor tensor, + int64_t dst, + ReduceOp op, + std::shared_ptr& group, + bool async_op); + + c10::optional allGatherMultiGPU( + std::vector>& output_tensor_lists, + const std::vector& input_tensor_list, + std::shared_ptr group, + bool async_op); + + // TODO TODO following APIs take python objects and unpickle them, how do we support these? 
+ // ProcessGroup::Work allGatherObject() + // ProcessGroup::Work gatherObject() + // ProcessGroup::Work broadcastObjectList() + + c10::optional allGather( + std::vector& tensor_list, + at::Tensor tensor, + std::shared_ptr group, + bool async_op); + + c10::optional allGatherCoalesced( + std::vector>& output_tensor_lists, + std::vector& input_tensor_list, + std::shared_ptr group, + bool async_op); + + c10::optional gather( + at::Tensor tensor, + std::vector& gather_list, + int64_t dst, + std::shared_ptr group, + bool async_op); + + c10::optional scatter( + at::Tensor tensor, + std::vector& scatter_list, + int64_t dst, + std::shared_ptr group, + bool async_op); + + ProcessGroup::Work reduceScatterMultiGPU( + std::vector& output_tensor_list, + const std::vector>& input_tensor_lists, + ReduceOp op, + std::shared_ptr group, + bool async_op); + + ProcessGroup::Work reduceScatter( + at::Tensor output, + const std::vector& input_list, + ReduceOp op, + std::shared_ptr group, + bool async_op); + + ProcessGroup::Work allToAllSingle( + at::Tensor output, + at::Tensor input, + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, + std::shared_ptr group, + bool async_op); + + ProcessGroup::Work allToAll( + std::vector& output_tensor_list, + const std::vector& input_tensor_list, + std::shared_ptr group, + bool async_op); + + ProcessGroup::Work barrier( + std::shared_ptr group, + bool async_op); + + std::shared_ptr newGroup( + std::vector ranks, + std::chrono::milliseconds timeout, + Backend backend); + + private: + DistributedC10d(){}; + + bool rankNotInGroup(std::shared_ptr group) const; + int64_t getGroupRank(std::shared_ptr group, const int64_t rank) + const; + int64_t getGlobalRank( + std::shared_ptr group, + const int64_t global_rank) const; + void checkDefaultPg() const; + int64_t getGroupSize(std::shared_ptr group) const; + int64_t getBackend(std::shared_ptr group); + + std::string backend_; + // TODO: Ask Alex what kind of equality we need. It determine whether we + // need to use ProcessGroup or ProcesGroup* as key. + std::unordered_map< + std::shared_ptr, + std::pair, std::shared_ptr>> + pg_map_; + + // Note, this is different mapping relationship than original Python + // implementation. + std::unordered_map, std::string> pg_names_; + + // Value is global_rank:group_rank mapping. + std::unordered_map, std::vector> + pg_group_ranks_; + + std::shared_ptr default_pg_; + + // Default value should be "env://" + std::string default_pg_init_method_; + + int64_t group_count_; }; - } // namespace c10d From 3fbddb92b1be1f70edced886745116b8daeebb17 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 6 Oct 2020 12:57:01 -0700 Subject: [PATCH 02/69] caffe2/plan_executor: wait for 1 minute after exception and then abort (#45297) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45297 If we have two concurrent substeps and one of them throws an exception and the other is blocking, we'll currently hang. This waits up to 1 minute for it to complete before terminating the process. 
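The shape of the fix is a condition-variable wait with a grace period before the process is killed. The snippet below is a minimal, self-contained sketch of that behavior only, not the plan-executor code from the diff itself; the worker bodies, `finishOne`, and `kTimeoutSeconds` are invented for illustration, with the constant standing in for the new `caffe2_plan_executor_exception_timeout` flag.

```
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  constexpr int kNumThreads = 2;
  constexpr int kTimeoutSeconds = 1;  // stands in for caffe2_plan_executor_exception_timeout

  std::mutex mutex;
  std::condition_variable cv;
  int done = 0;
  bool gotException = false;

  // Each worker reports completion on exit, like the new ScopeExitGuard does.
  auto finishOne = [&] {
    {
      std::lock_guard<std::mutex> lock(mutex);
      ++done;
    }
    cv.notify_all();
  };

  std::vector<std::thread> workers;
  // Substep 1: fails (here it just records the failure).
  workers.emplace_back([&] {
    {
      std::lock_guard<std::mutex> lock(mutex);
      gotException = true;
    }
    cv.notify_all();
    finishOne();
  });
  // Substep 2: blocks "forever", like the hanging substep described above.
  workers.emplace_back([&] {
    std::this_thread::sleep_for(std::chrono::hours(10));
    finishOne();
  });

  auto workersDone = [&] { return done == kNumThreads; };
  std::unique_lock<std::mutex> guard(mutex);
  // Wait until every worker is done or a failure has been recorded ...
  cv.wait(guard, [&] { return workersDone() || gotException; });
  // ... then give the remaining workers a grace period to stop on their own.
  cv.wait_for(guard, std::chrono::seconds(kTimeoutSeconds), workersDone);
  if (!workersDone() && gotException) {
    std::fprintf(stderr, "failed to stop concurrent workers after exception\n");
    std::abort();  // the real code rethrows the captured exception instead
  }
  for (auto& w : workers) {
    w.join();
  }
  return 0;
}
```

In the actual change the completion signal comes from the worker-side `ScopeExitGuard`, and instead of `std::abort()` the captured exception is rethrown through `ExceptionWrapperTerminate` so the process terminates with the original error; the `BlockingErrorPlan` death test below asserts on the same "failed to stop concurrent workers after exception" message.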
Test Plan: buck test caffe2/caffe2:caffe2_test_cpu -- PlanExecutorTest --stress-runs 100 Reviewed By: dahsh Differential Revision: D20850851 fbshipit-source-id: 330503775d8062a34645ba55fe38e6770de5e3c7 --- caffe2/core/plan_executor.cc | 63 +++++++++++++++++++++++++++++++ caffe2/core/plan_executor_test.cc | 59 ++++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index 3f70e96fffc8..c7c0200e5880 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -17,10 +17,18 @@ C10_DEFINE_bool( "If used we will handle exceptions in executor threads. " "This avoids SIGABRT but may cause process to deadlock"); +C10_DEFINE_int( + caffe2_plan_executor_exception_timeout, + 60, + "Number of seconds to wait for concurrent threads to stop on exception" + "before terminating."); + namespace caffe2 { namespace { +// ExceptionWrapper holds an exception. If exception pointers are being used, +// it'll hold the original exception pointer otherwise just the message. class ExceptionWrapper { public: ExceptionWrapper() : hasException_(false) {} @@ -39,6 +47,10 @@ class ExceptionWrapper { #endif } + const std::string& what() const { + return exceptionMsg_; + } + operator bool() { return hasException_; } @@ -51,6 +63,33 @@ class ExceptionWrapper { std::string exceptionMsg_; }; +// ExceptionWrapperTerminate terminates the program with the specified +// exception. This preserves the exception ptr and ExceptionTracer will +// correctly grab it on exit. +class ExceptionWrapperTerminate { + public: + explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) : ew_(std::move(ew)) {} + + ~ExceptionWrapperTerminate() { + ew_.rethrowException(); + } + + private: + ExceptionWrapper ew_; +}; + +// ScopeExitGuard runs the provided function when it's destructed. +class ScopeExitGuard { + public: + explicit ScopeExitGuard(std::function&& f) : f_(std::move(f)) {} + ~ScopeExitGuard() { + f_(); + } + + private: + std::function f_; +}; + struct NetDefInfo { const NetDef* netDef; // in order to keep the "override existing nets" on the top-level workflow, @@ -460,9 +499,16 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { << " with " << step.substep().size() << " concurrent substeps"; std::atomic next_substep{0}; + std::condition_variable cv; + std::atomic done{0}; std::mutex exception_mutex; ExceptionWrapper first_exception; auto worker = [&]() { + ScopeExitGuard on_exit([&] { + done += 1; + cv.notify_all(); + }); + auto num_substeps = compiledStep->recurringSubsteps.size(); int substep_id = next_substep++ % num_substeps; if (compiledStep->gotFailure) { @@ -500,6 +546,23 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { for (size_t i = 0; i < numThreads; ++i) { threads.emplace_back(worker); } + + auto workersDone = [&] { return done == numThreads; }; + + // If we get an exception, try to wait for all threads to stop + // gracefully. 
+ std::unique_lock guard(exception_mutex); + cv.wait(guard, [&] { return workersDone() || first_exception; }); + cv.wait_for( + guard, + std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout), + [&] { return workersDone(); }); + if (!workersDone() && first_exception) { + LOG(ERROR) << "failed to stop concurrent workers after exception: " + << first_exception.what(); + ExceptionWrapperTerminate(std::move(first_exception)); + } + for (auto& thread : threads) { thread.join(); } diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc index 86f145d72a09..1b0eb0e718a2 100644 --- a/caffe2/core/plan_executor_test.cc +++ b/caffe2/core/plan_executor_test.cc @@ -67,6 +67,29 @@ class ErrorOp final : public Operator { REGISTER_CPU_OPERATOR(Error, ErrorOp); OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); +static std::atomic blockingErrorRuns{0}; +class BlockingErrorOp final : public Operator { + public: + BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // First n op executions should block and then start throwing errors. + if (blockingErrorRuns.fetch_sub(1) >= 1) { + LOG(INFO) << "blocking"; + while (true) { + std::this_thread::sleep_for(std::chrono::hours(10)); + } + } else { + LOG(INFO) << "throwing"; + throw TestError(); + } + } +}; + +REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); +OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); + PlanDef parallelErrorPlan() { PlanDef plan_def; @@ -101,10 +124,12 @@ PlanDef parallelErrorPlan() { } struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard() { + HandleExecutorThreadExceptionsGuard(int timeout = 60) { globalInit({ "caffe2", "--caffe2_handle_executor_threads_exceptions=1", + "--caffe2_plan_executor_exception_timeout=" + + caffe2::to_string(timeout), }); } @@ -139,6 +164,38 @@ TEST(PlanExecutorTest, ErrorAsyncPlan) { ASSERT_EQ(cancelCount, 1); } +TEST(PlanExecutorTest, BlockingErrorPlan) { + ASSERT_DEATH( + [] { + HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); + + PlanDef plan_def; + + std::string plan_def_template = R"DOC( + network { + name: "net" + op { + type: "BlockingError" + } + } + execution_step { + num_concurrent_instances: 2 + substep { + network: "net" + } + } + )DOC"; + + CAFFE_ENFORCE( + TextFormat::ParseFromString(plan_def_template, &plan_def)); + Workspace ws; + blockingErrorRuns = 1; + ws.RunPlan(plan_def); + FAIL() << "shouldn't have reached this point"; + }(), + "failed to stop concurrent workers after exception: test error"); +} + } // namespace caffe2 #endif From c1af91a13aa82661d0d15ed467b9c68dc36b914b Mon Sep 17 00:00:00 2001 From: n-v-k <71945655+n-v-k@users.noreply.github.com> Date: Tue, 6 Oct 2020 13:14:58 -0700 Subject: [PATCH 03/69] [caffe2] SliceOp axes indexing fixes. 
(#45432) Summary: Fixes https://github.com/pytorch/pytorch/issues/45431 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45432 Reviewed By: albanD Differential Revision: D24132547 Pulled By: dzhulgakov fbshipit-source-id: d67f7a92d806fb8ac8fc8f522b251d3a8fb83037 --- caffe2/operators/slice_op.cc | 9 +++++---- caffe2/operators/slice_op.cu | 19 ++++++++++--------- caffe2/operators/slice_op.h | 23 ++++++++++++----------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/caffe2/operators/slice_op.cc b/caffe2/operators/slice_op.cc index 7acf854ba9da..f9fd39303261 100644 --- a/caffe2/operators/slice_op.cc +++ b/caffe2/operators/slice_op.cc @@ -17,7 +17,7 @@ Produces a slice of the input tensor. - Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments. -- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). +- If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). Github Links: - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc @@ -67,11 +67,11 @@ print("Y:", workspace.FetchBlob("Y")) .Input( 1, "starts", - "(*Tensor``*): 1D tensor of start-indices for each dimension of data") + "(*Tensor``*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Input( 2, "ends", - "(*Tensor``*): 1D tensor of end-indices for each dimension of data") + "(*Tensor``*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Arg("starts", "(*Tuple(int)*): list of starting indices") .Arg("ends", "(*Tuple(int)*): list of ending indices") .TensorInferenceFunction([](const OperatorDef& def, @@ -90,9 +90,10 @@ print("Y:", workspace.FetchBlob("Y")) for (int i = 0; i < data.dims_size(); ++i) { if (i >= starts.size()) { + dst_sizes[i] = data.dims(i); continue; } - if (data.dims_size() > 0) { + if (data.dims(i) > 0) { auto start = starts[i]; auto end = ends[i]; if (start < 0) { diff --git a/caffe2/operators/slice_op.cu b/caffe2/operators/slice_op.cu index 7a843fee3a52..184385310c9c 100644 --- a/caffe2/operators/slice_op.cu +++ b/caffe2/operators/slice_op.cu @@ -74,22 +74,23 @@ bool SliceImplGpu( if (i >= starts.numel()) { starts_idx[i] = 0; ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -115,7 +116,7 @@ bool SliceImplGpu( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to 
slice in 1 dimension."); dim = i; @@ -154,7 +155,7 @@ bool SliceImplGpu( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -187,7 +188,7 @@ bool SliceImplGpu( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) { diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index 8d1990e54c38..9706472315b6 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -33,23 +33,24 @@ bool SliceImpl( for (int i = 0; i < data.dim(); ++i) { if (i >= starts.numel()) { starts_idx[i] = 0; - ends_idx[i] = data.sizes()[i]; + ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } - if (data.sizes()[i] > 0) { + if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -78,7 +79,7 @@ bool SliceImpl( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; @@ -117,7 +118,7 @@ bool SliceImpl( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -155,7 +156,7 @@ bool SliceImpl( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) { From a69a78daa2a53e2c4b1088b3acc4ed340aca3c2c Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 6 Oct 2020 13:20:14 -0700 Subject: [PATCH 04/69] Use smaller N to speed up TestForeach (#45785) Summary: Between September 25 and September 27, approximately half an hour was added to the running time of `pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test`. Judging from the CircleCI data, it looks like the majority of the new time was added by the following PRs: - https://github.com/pytorch/pytorch/issues/44550 - https://github.com/pytorch/pytorch/issues/45298 I'm not sure what to do about https://github.com/pytorch/pytorch/issues/44550, but it looks like https://github.com/pytorch/pytorch/issues/45298 increased the `N` for `TestForeach` from just 20 to include both 30 and 300. 
This PR would remove the 300, decreasing the test time by a couple orders of magnitude (at least when running it on my devserver), from over ten minutes to just a few seconds. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45785 Reviewed By: malfet Differential Revision: D24094782 Pulled By: samestep fbshipit-source-id: 2476cee9d513b2b07bc384de751e08d0e5d8b5e7 --- test/test_foreach.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 683b4fe28167..7f19c17c2558 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1,8 +1,10 @@ import torch import unittest -from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM +from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, skipCUDAIfRocm +N_values = [20] if not TEST_WITH_SLOW else [30, 300] + class TestForeach(TestCase): foreach_bin_ops = [ torch._foreach_add, @@ -50,7 +52,7 @@ def _get_test_data(self, device, dtype, N): return tensors def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op): - for N in [30, 300]: + for N in N_values: tensors1 = self._get_test_data(device, dtype, N) tensors2 = self._get_test_data(device, dtype, N) @@ -68,7 +70,7 @@ def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op): self.assertEqual(tensors1, expected) def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op): - for N in [30, 300]: + for N in N_values: tensors1 = self._get_test_data(device, dtype, N) # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. control_dtype = torch.float32 if (self.device_type == 'cuda' and @@ -83,7 +85,7 @@ def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op): self.assertEqual(tensors1, expected) def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op): - for N in [30, 300]: + for N in N_values: tensors = self._get_test_data(device, dtype, N) tensors1 = self._get_test_data(device, dtype, N) tensors2 = self._get_test_data(device, dtype, N) @@ -174,7 +176,7 @@ def test_addcdiv(self, device, dtype): # @dtypes(*torch.testing.get_all_dtypes()) def test_int_scalar(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, self.foreach_bin_ops_, self.torch_bin_ops): @@ -215,7 +217,7 @@ def test_int_scalar(self, device, dtype): # Current schema is using 'float[]' as scalar list type. 
@dtypes(*torch.testing.get_all_dtypes()) def test_int_scalarlist(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops_sl, self.foreach_bin_ops_sl_, self.torch_bin_ops): @@ -260,7 +262,7 @@ def test_int_scalarlist(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_float_scalar(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, self.foreach_bin_ops_, self.torch_bin_ops): @@ -303,7 +305,7 @@ def test_float_scalar(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_float_scalarlist(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops_sl, self.foreach_bin_ops_sl_, self.torch_bin_ops): @@ -363,7 +365,7 @@ def test_float_scalarlist(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_complex_scalar(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, self.foreach_bin_ops_, self.torch_bin_ops): @@ -401,7 +403,7 @@ def test_complex_scalar(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_complex_scalarlist(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops_sl, self.foreach_bin_ops_sl_, self.torch_bin_ops): @@ -426,7 +428,7 @@ def test_complex_scalarlist(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_bool_scalar(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, self.foreach_bin_ops_, self.torch_bin_ops): @@ -476,7 +478,7 @@ def test_bool_scalar(self, device, dtype): @dtypes(*torch.testing.get_all_dtypes()) def test_bool_scalarlist(self, device, dtype): - for N in [30, 300]: + for N in N_values: for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops_sl, self.foreach_bin_ops_sl_, self.torch_bin_ops): @@ -689,7 +691,7 @@ def test_div_list(self, device, dtype): self.skipTest("Skipped! See https://github.com/pytorch/pytorch/issues/44489") return - for N in [30, 300]: + for N in N_values: tensors1 = self._get_test_data(device, dtype, N) if dtype in [torch.bfloat16, torch.bool, torch.float16]: From e154b36685330a819d96c2e4482ccf1cd06abba4 Mon Sep 17 00:00:00 2001 From: Vaidotas Simkus Date: Tue, 6 Oct 2020 13:40:35 -0700 Subject: [PATCH 05/69] Standardized clamp kernels to Numpy-like implementation (#43288) Summary: **BC-breaking note** For ease of exposition let a_min be the value of the "min" argument to clamp, and a_max be the value of the "max" argument to clamp. This PR changes the behavior of torch.clamp to always compute min(max(a, a_min), a_max). torch.clamp currently computes this in its vectorized CPU specializations: https://github.com/pytorch/pytorch/blob/78b95b6204809822def6dd1b06d03cf002cd30c5/aten/src/ATen/cpu/vec256/vec256_double.h#L304 but in other places it clamps differently: https://github.com/pytorch/pytorch/blob/78b95b6204809822def6dd1b06d03cf002cd30c5/aten/src/ATen/cpu/vec256/vec256_base.h#L624 https://github.com/pytorch/pytorch/blob/78b95b6204809822def6dd1b06d03cf002cd30c5/aten/src/ATen/native/cuda/UnaryOpsKernel.cu#L160 These implementations are the same when a_min < a_max, but divergent when a_min > a_max. 
This divergence is easily triggered: ``` t = torch.arange(200).to(torch.float) torch.clamp(t, 4, 2)[0] : tensor(2.) torch.clamp(t.cuda(), 4, 2)[0] : tensor(4., device='cuda:0') torch.clamp(torch.tensor(0), 4, 2) : tensor(4) ``` This PR makes the behavior consistent with NumPy's clip. C++'s std::clamp's behavior is undefined when a_min > a_max, but Clang's std::clamp will return 10 in this case (although the program, per the above comment, is in error). Python has no standard clamp implementation. **PR Summary** Fixes discrepancy between AVX, CUDA, and base vector implementation for clamp, such that all implementations are consistent and use min(max_vec, max(min_vec, x) formula, thus making it equivalent to numpy.clip in all implementations. The same fix as in https://github.com/pytorch/pytorch/issues/32587 but isolated to the kernel change only, so that the internal team can benchmark. Pull Request resolved: https://github.com/pytorch/pytorch/pull/43288 Reviewed By: colesbury Differential Revision: D24079453 Pulled By: mruberry fbshipit-source-id: 67f30d2f2c86bbd3e87080b32f00e8fb131a53f7 --- aten/src/ATen/cpu/vec256/vec256_base.h | 33 +--- .../ATen/cpu/vec256/vec256_complex_double.h | 26 --- .../ATen/cpu/vec256/vec256_complex_float.h | 26 --- aten/src/ATen/native/UnaryOps.cpp | 8 +- aten/src/ATen/native/cpu/UnaryOpsKernel.cpp | 15 +- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 22 ++- test/test_torch.py | 155 ++++++++++-------- torch/_torch_docs.py | 18 +- .../_internal/common_methods_invocations.py | 4 - 9 files changed, 116 insertions(+), 191 deletions(-) diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index b6cc1db24028..edce0e3a2cce 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -615,23 +615,12 @@ inline T minimum(const T& a, const T& b) { return c; } -// To save BC, it will not propagate NaN based on IEEE 754 201X template ::value, int>::type = 0> Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { - c[i] = a[i] < min_vec[i] ? min_vec[i] : (a[i] > max_vec[i] ? max_vec[i] : a[i]); - } - return c; -} - -template ::value, int>::type = 0> -Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : (std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]); + c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); } return c; } @@ -646,16 +635,6 @@ Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { return c; } -template ::value, int>::type = 0> -Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]; - } - return c; -} - template ::value, int>::type = 0> Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { @@ -666,16 +645,6 @@ Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { return c; } -template ::value, int>::type = 0> -Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? 
min_vec[i] : a[i]; - } - return c; -} - struct Vec256i; #ifdef CPU_CAPABILITY_AVX2 diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h index 0827b33a3122..d2ae6f46b44e 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h @@ -416,32 +416,6 @@ Vec256> inline minimum(const Vec256>& return _mm256_or_pd(min, isnan); } -template <> -Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_pd(_mm256_blendv_pd(a, min, max_mask), max, min_mask); -} - -template <> -Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); - return _mm256_blendv_pd(a, min, max_mask); -} - -template <> -Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_pd(a, max, min_mask); -} - template <> Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { return _mm256_and_pd(a, b); diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h index ea931acc494b..8b4eba07f421 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h @@ -456,32 +456,6 @@ Vec256> inline minimum(const Vec256>& a, return _mm256_or_ps(min, isnan); } -template <> -Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_ps(_mm256_blendv_ps(a, min, max_mask), max, min_mask); -} - -template <> -Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); - return _mm256_blendv_ps(a, min, max_mask); -} - -template <> -Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_ps(a, max, min_mask); -} - template <> Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { return _mm256_and_ps(a, b); diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index e2b5639f8dc9..68a1e45f0974 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -501,7 +501,7 @@ Tensor signbit(const Tensor& self) { } Tensor& clamp_out(Tensor& result, const Tensor& self, optional min, optional max) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); + TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs."); if (min && max) { TORCH_CHECK(self.layout() == Layout::Strided, "clamp only supports strided layout, got: ", self.layout()); @@ -512,7 +512,7 @@ Tensor& clamp_out(Tensor& result, const Tensor& self, optional min, opti } else if (min) { at::clamp_min_out(result, self, 
*min); } else { - AT_ERROR("At least one of 'min' or 'max' must not be None"); + TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None"); } return result; } @@ -527,7 +527,7 @@ Tensor& clamp_(Tensor& self, optional min, optional max) { } Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); + TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_max only supports strided layout, got: ", self.layout()); auto iter = TensorIterator::unary_op(result, self); @@ -545,7 +545,7 @@ Tensor& clamp_max_(Tensor& self, Scalar max) { } Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); + TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_min only supports strided layout, got: ", self.layout()); auto iter = TensorIterator::unary_op(result, self); diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 84c3ceed3a23..beb50ee2c936 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -411,36 +411,33 @@ static void nan_to_num_kernel( } static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() { auto min = min_scalar.to(); auto max = max_scalar.to(); auto min_vec = Vec256(min); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : (zabs_(a) > zabs_(max) ? max : a); }, + [=](scalar_t a) -> scalar_t { return std::min(std::max(a, min), max); }, [=](Vec256 a) { return vec256::clamp(a, min_vec, max_vec); }); }); } static void clamp_max_kernel(TensorIterator& iter, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() { auto max = max_scalar.to(); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) > zabs_(max) ? max : a; }, + [=](scalar_t a) -> scalar_t { return std::min(a, max); }, [=](Vec256 a) { return vec256::clamp_max(a, max_vec); }); }); } static void clamp_min_kernel(TensorIterator& iter, Scalar min_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() { auto min = min_scalar.to(); auto min_vec = Vec256(min); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? 
min : a; }, + [=](scalar_t a) -> scalar_t { return std::max(a, min); }, [=](Vec256 a) { return vec256::clamp_min(a, min_vec); }); }); } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 5b545471fb34..6f5c9221dee6 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace at { @@ -158,7 +159,12 @@ void clamp_kernel_cuda(TensorIterator& iter, Scalar min_value, Scalar max_value) auto lower = min_value.to(); auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return (v < lower) ? lower : (v > upper ? upper : v); + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(::max(v, lower), upper); + } }); }); } @@ -167,7 +173,12 @@ void clamp_min_kernel_cuda(TensorIterator& iter, Scalar min_value) { AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_min_cuda", [&]() { auto lower = min_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return v < lower ? lower : v; + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::max(v, lower); + } }); }); } @@ -176,7 +187,12 @@ void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) { AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_max_cuda", [&]() { auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return v > upper ? upper : v; + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(v, upper); + } }); }); } diff --git a/test/test_torch.py b/test/test_torch.py index 7da38b211dc5..3ff5a1d73822 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6386,85 +6386,96 @@ def test_logical_and(self, device, dtypes): def test_logical_or(self, device, dtypes): self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) + def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans): + """ + Creates a random tensor for a given device and dtype, and computes the expected clamped + values given the min_vals and/or max_vals. + If with_nans is provided, then some values are randomly set to nan. + """ + X = torch.rand(100, device=device).mul(50).add(-25) # uniform in [-25, 25] + X = X.to(dtype) + if with_nans: + mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device) + X[mask] = nan + + if isinstance(min_vals, torch.Tensor): + min_vals = min_vals.cpu().numpy() + + if isinstance(max_vals, torch.Tensor): + max_vals = max_vals.cpu().numpy() + + # Use NumPy implementation as reference + X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) + return X, X_clamped + # Tests clamp and its alias, clip - def test_clamp(self, device): - op_list = ((torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_), - (torch.clip, torch.Tensor.clip, torch.Tensor.clip_)) - for op, method_op, inplace_op in op_list: - - m1 = torch.rand(100, device=device).mul(5).add(-2.5) # uniform in [-2.5, 2.5] - # just in case we're extremely lucky. 
- min_val = -1 - max_val = 1 - m1[1] = min_val - m1[2] = max_val + @dtypes(torch.int64, torch.float32) + def test_clamp(self, device, dtype): + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) - res1 = m1.clone() - inplace_op(res1, min_val, max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, min(max_val, res2[i])) - self.assertEqual(res1, res2) + # min/max argument product + args = product((-10, None), (10, None)) - out = m1.clone() - op(m1, min=min_val, max=max_val, out=out) - self.assertEqual(out, res1) + for op in op_list: + for min_val, max_val in args: + if min_val is None and max_val is None: + continue - res1 = op(m1, min=min_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, res2[i]) - self.assertEqual(res1, res2) + X, Y_expected = self.generate_clamp_baseline(device, dtype, + min_vals=min_val, + max_vals=max_val, + with_nans=False) - op(m1, min=min_val, out=out) - self.assertEqual(out, res1) + # Test op + X1 = X.clone() # So that the in-place ops do not change X + Y_actual = op(X1, min_val, max_val) + self.assertEqual(Y_expected, Y_actual) - res1 = op(m1, max=max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = min(max_val, res2[i]) - self.assertEqual(res1, res2) + # Test op-out behavior (out does not exist for method versions) + if op in (torch.clamp, torch.clip): + Y_out = torch.empty_like(X) + op(X, min=min_val, max=max_val, out=Y_out) + self.assertEqual(Y_expected, Y_out) + + def test_clamp_propagates_nans(self, device): + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + + # min/max argument product + args = product((-10, None), (10, None)) + + for op in op_list: + for min_val, max_val in args: + if min_val is None and max_val is None: + continue - op(m1, max=max_val, out=out) - self.assertEqual(out, res1) - - # if the tensor contains nan case - test_tens = torch.tensor([nan], device=device) - - res1 = test_tens.clone() - inplace_op(res1, min_val, max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(min(res2[i], max_val), min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - out = test_tens.clone() - op(test_tens, min=min_val, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - res1 = op(test_tens, min=min_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(res2[i], min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - op(test_tens, min=min_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - res1 = op(test_tens, max=max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = min(res2[i], max_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - op(test_tens, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - error_msg = 'At least one of \'min\' or \'max\' must not be None' - with self.assertRaisesRegex(RuntimeError, error_msg): - method_op(m1) - with self.assertRaisesRegex(RuntimeError, error_msg): - inplace_op(m1) + X, Y_expected = self.generate_clamp_baseline(device, torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True) + Y_expected = torch.isnan(Y_expected) + + # Test op + X1 = X.clone() # So that the in-place ops do not change X + Y_actual = op(X1, min_val, max_val) + self.assertEqual(Y_expected, torch.isnan(Y_actual)) + + # Test op-out 
behavior (out does not exist for method versions) + if op in (torch.clamp, torch.clip): + Y_out = torch.empty_like(X) + op(X, min_val, max_val, out=Y_out) + self.assertEqual(Y_expected, torch.isnan(Y_out)) + + def test_clamp_raises_arg_errors(self, device): + X = torch.randn(100, dtype=torch.float, device=device) + error_msg = 'At least one of \'min\' or \'max\' must not be None' + with self.assertRaisesRegex(RuntimeError, error_msg): + X.clamp() + with self.assertRaisesRegex(RuntimeError, error_msg): + X.clamp_() + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.clamp(X) @onlyOnCPUAndCUDA @dtypes(torch.float32, torch.float64) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 6c641c3df140..28f9ebf1a585 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1651,18 +1651,12 @@ def merge_dicts(*dicts): add_docstr(torch.clamp, r""" clamp(input, min, max, *, out=None) -> Tensor -Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return -a resulting tensor: +Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. +Let min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: .. math:: - y_i = \begin{cases} - \text{min} & \text{if } x_i < \text{min} \\ - x_i & \text{if } \text{min} \leq x_i \leq \text{max} \\ - \text{max} & \text{if } x_i > \text{max} - \end{cases} + y_i = \min(\max(x_i, \text{min\_value}), \text{max\_value}) """ + r""" -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min` -and :attr:`max` must be real numbers, otherwise they should be integers. Args: {input} @@ -1684,9 +1678,6 @@ def merge_dicts(*dicts): Clamps all elements in :attr:`input` to be larger or equal :attr:`min`. -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, :attr:`value` -should be a real number, otherwise it should be an integer. - Args: {input} @@ -1706,9 +1697,6 @@ def merge_dicts(*dicts): Clamps all elements in :attr:`input` to be smaller or equal :attr:`max`. -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, :attr:`value` -should be a real number, otherwise it should be an integer. 
- Args: {input} diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f26e6c75d37e..a6887395c19a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1659,10 +1659,6 @@ def unpack_variables(args): def exclude_tensor_method(name, test_name): # there are no tensor equivalents for these (inplace or out) exclude_all_tensor_method_by_test_name = { - 'test_clamp_min', - 'test_clamp_max', - 'test_clamp_min_scalar', - 'test_clamp_max_scalar', 'test_slice', 'test_where', 'test_where_broadcast_all', From a3662fa78c42bc2ae6b70fe6f024fb73fed59bcc Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 6 Oct 2020 13:56:47 -0700 Subject: [PATCH 06/69] Minor gradcheck update to reduce computations (#45757) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45757 Test Plan: Imported from OSS Reviewed By: glaringlee Differential Revision: D24137143 Pulled By: anjali411 fbshipit-source-id: e0174ec03d93b1fedf27baa72c3542dac0b70058 --- torch/autograd/gradcheck.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 7ca1fccfce54..b2bea4570c2a 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -102,13 +102,11 @@ def fn_out(): d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj() elif ds_dx.is_complex(): # R -> C # w_d = conj_w_d = 0.5 * ds_dx - dL_dz_conj = 0.5 * (grad_out.conjugate() * ds_dx + grad_out * ds_dx.conj()) - # The above formula is derived for a C -> C function that's a part of - # bigger function with real valued output. From separate calculations, - # it can be verified that the gradient for R -> C function - # equals to real value of the result obtained from the generic formula for - # C -> C functions used above. - d[d_idx] = torch.real(dL_dz_conj) + # dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()] + # = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()] + # = 0.5 * 2 * real(grad_out.conj() * ds_dx) + # = real(grad_out.conj() * ds_dx) + d[d_idx] = torch.real(grad_out.conjugate() * ds_dx) else: # R -> R d[d_idx] = ds_dx * grad_out From 255b0e839f500bf0ab6fbcac3a8e4867ea44cd86 Mon Sep 17 00:00:00 2001 From: lixinyu Date: Tue, 6 Oct 2020 14:55:04 -0700 Subject: [PATCH 07/69] C++ APIs CUDA Stream Note (Set/Get part) (#45754) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45754 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D24085103 Pulled By: glaringlee fbshipit-source-id: c9641c2baadcf93b84733c037ce91b670dde5f96 --- docs/cpp/source/notes/tensor_cuda_stream.rst | 276 +++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 docs/cpp/source/notes/tensor_cuda_stream.rst diff --git a/docs/cpp/source/notes/tensor_cuda_stream.rst b/docs/cpp/source/notes/tensor_cuda_stream.rst new file mode 100644 index 000000000000..a453f94b5581 --- /dev/null +++ b/docs/cpp/source/notes/tensor_cuda_stream.rst @@ -0,0 +1,276 @@ +Tensor CUDA Stream API +====================== + +A `CUDA Stream`_ is a linear sequence of execution that belongs to a specific CUDA device. +The PyTorch C++ API supports CUDA streams with the CUDAStream class and useful helper functions to make streaming operations easy. +You can find them in `CUDAStream.h`_. This note provides more details on how to use Pytorch C++ CUDA Stream APIs. + +.. 
_CUDA Stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams +.. _CUDAStream.h: https://pytorch.org/cppdocs/api/file_c10_cuda_CUDAStream.h.html#file-c10-cuda-cudastream-h +.. _CUDAStreamGuard.h: https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_c_u_d_a_stream_guard.html + +Acquiring CUDA stream +********************* + +Pytorch's C++ API provides the following ways to acquire CUDA stream: + +1. Acquire a new stream from the CUDA stream pool, streams are preallocated from the pool and returned in a round-robin fashion. + +.. code-block:: cpp + + CUDAStream getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); + +.. tip:: + + You can request a stream from the high priority pool by setting isHighPriority to true, or a stream for a specific device + by setting device index (defaulting to the current CUDA stream's device index). + +2. Acquire the default CUDA stream for the passed CUDA device, or for the current device if no device index is passed. + +.. code-block:: cpp + + CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); + +.. tip:: + + The default stream is where most computation occurs when you aren't explicitly using streams. + +3. Acquire the current CUDA stream, for the CUDA device with index ``device_index``, or for the current device if no device index is passed. + +.. code-block:: cpp + + CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); + +.. tip:: + + The current CUDA stream will usually be the default CUDA stream for the device, but it may be different if someone + called ``setCurrentCUDAStream`` or used ``StreamGuard`` or ``CUDAStreamGuard``. + + + +Set CUDA stream +*************** + +Pytorch's C++ API provides the following ways to set CUDA stream: + +1. Set the current stream on the device of the passed in stream to be the passed in stream. + +.. code-block:: cpp + + void setCurrentCUDAStream(CUDAStream stream); + +.. attention:: + + This function may have nosthing to do with the current device. It only changes the current stream on the stream's device. + We recommend using ``CUDAStreamGuard``, instead, since it switches to the stream's device and makes it the current stream on that device. + ``CUDAStreamGuard`` will also restore the current device and stream when it's destroyed + +2. Use ``CUDAStreamGuard`` to switch to a CUDA stream within a scope, it is defined in `CUDAStreamGuard.h`_ + +.. tip:: + + Use ``CUDAMultiStreamGuard`` if you need to set streams on multiple CUDA devices. + +CUDA Stream Usage Examples +************************** + +1. Acquiring and setting CUDA stream on the same device + +.. code-block:: cpp + + // This example shows how to acquire and set CUDA stream on the same device. 
+ // `at::cuda::setCurrentCUDAStream` is used to set current CUDA stream + + // create a tensor on device 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch.device(torch::kCUDA)); + // get a new CUDA stream from CUDA stream pool on device 0 + at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(); + // set current CUDA stream from default stream to `myStream` on device 0 + at::cuda::setCurrentCUDAStream(myStream); + // sum() on tensor0 uses `myStream` as current CUDA stream + tensor0.sum(); + + // get the default CUDA stream on device 0 + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); + // set current CUDA stream back to default CUDA stream on devide 0 + at::cuda::setCurrentCUDAStream(defaultStream); + // sum() on tensor0 uses `defaultStream` as current CUDA stream + tensor0.sum(); + +.. code-block:: cpp + + // This example is the same as previous example, but explicitly specify device + // index and use CUDA stream guard to set current CUDA stream + + // create a tensor on device 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch.device(torch::kCUDA)); + // get a new stream from CUDA stream pool on device 0 + at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(false, 0); + // set the current CUDA stream to `myStream` within the scope using CUDA stream guard + { + at::cuda::CUDAStreamGuard guard(myStream); + // current CUDA stream is `myStream` from here till the end of bracket. + // sum() on tensor0 uses `myStream` as current CUDA stream + tensor0.sum(); + } + // current CUDA stream is reset to default CUDA stream after CUDA stream guard is destroyed + // sum() on tensor0 uses default CUDA stream on device 0 as current CUDA stream + tensor0.sum(); + +.. attention:: + + Above code is running on the same CUDA device. `setCurrentCUDAStream` will always set current CUDA stream on current device, + but note that `setCurrentCUDASteram` actually set current stream on the device of passed in CUDA stream. + + +2. Acquiring and setting CUDA streams on multiple devices. + +.. code-block:: cpp + + // This example shows how to acquire and set CUDA stream on two devices. 
+ + // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 + at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); + at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); + + // set current CUDA stream to `myStream0` on device 0 + at::cuda::setCurrentCUDAStream(myStream0); + // set current CUDA stream to `myStream1` on device 1 + at::cuda::setCurrentCUDAStream(myStream1); + + // create a tensor on device 0, no need to specify device index since + // current device index is 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(at::kCUDA)); + // sum() on tensor0 use `myStream0` as current CUDA stream on device 0 + tensor0.sum(); + + // change the current device index to 1 by using CUDA device guard within a braket scope + { + at::cuda::CUDAGuard device_guard{1}; + // create a tensor on device 1 + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device(at::kCUDA)); + // sum() on tensor 1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current device is reset to device 0 after device_guard is destroyed + + // acquire a new CUDA stream on device 1 + at::cuda::CUDAStream myStream1_1 = at::cuda::getStreamFromPool(false, 1); + // create a new tensor on device 1 + torch::Tensor tensor1 = torch::ones({2, 2}, torch.device({torch::kCUDA, 1})); + + // change the current device index to 1 and current CUDA stream on device 1 + // to `myStream1_1` using CUDA stream guard within a scope + { + at::cuda::CUDAStreamGuard stream_guard(myStream1_1); + // sum() on tensor1 use `myStream1_1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current device is reset to device 0 and current CUDA stream on device 1 is + // reset to `myStream1` + + // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + + +3. Working with CUDA multistream guard + +.. code-block:: cpp + + // This example shows how to use CUDA multistream guard to set + // two streams on two devices at the same time. + + // create two tensor, one on device 0, one on device 1 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device({torch::kCUDA, 0})); + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); + + // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 + at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); + at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); + + // set current CUDA stream on device 0 to `myStream0` and + // set current CUDA stream on device 1 to `myStream1` CUDA using multistream guard + { + at::cuda::CUDAMultiStreamGuard multi_guard({myStream0, myStream1}); + + // sum() on tensor0 uses `myStream0` as current CUDA stream on device 0 + tensor0.sum(); + // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current CUDA stream on device 0 is reset to default CUDA stream on device 0 + // current CUDA stream on device 1 is reset to default CUDA stream on device 1 + + // sum() on tensor0 uses default CUDA stream as current CUDA stream on device 0 + tensor0.sum(); + // sum() on tensor1 uses defualt CUDA stream as current CUDA stream on device 1 + tensor1.sum(); + +.. attention:: + ``CUDAMultiStreamGuard`` does not change current device index, it only changes the stream on + each passed in stream's device. Other than scope controlling, this guard is equivalent to + calling ``setCurrentCUDAStream`` on each passed in stream. + +4. 
+
+.. code-block:: cpp
+
+ // This is a skeleton example that shows how to handle CUDA streams on multiple devices.
+ // Suppose we want to do work on non-default streams on two devices simultaneously, and we
+ // already have streams on both devices in two vectors. The following code shows three ways
+ // of acquiring and setting the streams.
+
+ // Usage 0: acquire CUDA streams and set the current CUDA stream with `setCurrentCUDAStream`
+ // create a CUDA stream vector `streams0` on device 0
+ std::vector<at::cuda::CUDAStream> streams0 =
+   {at::cuda::getDefaultCUDAStream(), at::cuda::getStreamFromPool()};
+ // set the current stream to `streams0[0]` on device 0
+ at::cuda::setCurrentCUDAStream(streams0[0]);
+
+ // create a CUDA stream vector `streams1` on device 1 using a CUDA device guard
+ std::vector<at::cuda::CUDAStream> streams1;
+ {
+   // the device index is set to 1 within this scope
+   at::cuda::CUDAGuard device_guard(1);
+   streams1.push_back(at::cuda::getDefaultCUDAStream());
+   streams1.push_back(at::cuda::getStreamFromPool());
+ }
+ // the device index is reset to 0 after device_guard is destroyed
+
+ // set the current stream to `streams1[0]` on device 1
+ at::cuda::setCurrentCUDAStream(streams1[0]);
+
+
+ // Usage 1: use a CUDA device guard to change the current device index only
+ {
+   at::cuda::CUDAGuard device_guard(1);
+
+   // the current device index is changed to 1 within this scope
+   // the current CUDA stream on device 1 is still `streams1[0]`, no change
+ }
+ // the current device index is reset to 0 after `device_guard` is destroyed
+
+
+ // Usage 2: use a CUDA stream guard to change both the current device index and the current CUDA stream
+ {
+   at::cuda::CUDAStreamGuard stream_guard(streams1[1]);
+
+   // the current device index and the current CUDA stream are set to 1 and `streams1[1]` within this scope
+ }
+ // after `stream_guard` is destroyed, the current device index is reset to 0 (where the current
+ // stream is still `streams0[0]`) and the current CUDA stream on device 1 is reset to `streams1[0]`
+
+
+ // Usage 3: use a CUDA multistream guard to change multiple streams on multiple devices
+ {
+   // this is the same as calling `at::cuda::setCurrentCUDAStream` on both streams
+   at::cuda::CUDAMultiStreamGuard multi_guard({streams0[1], streams1[1]});
+
+   // the current device index is not changed, it is still 0
+   // the current CUDA streams on device 0 and device 1 are set to `streams0[1]` and `streams1[1]`
+ }
+ // the current CUDA streams on device 0 and device 1 are reset to `streams0[0]` and `streams1[0]`
+ // after `multi_guard` is destroyed.
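+
+The guards above only switch which stream is current; a related question is how to check
+the current stream and how to wait for work submitted on a non-default stream. The snippet
+below is a minimal sketch of one way to do this, assuming ``at::cuda::getCurrentCUDAStream``
+and ``CUDAStream::synchronize`` from the same stream API used above.
+
+.. code-block:: cpp
+
+ // minimal sketch (assumes the stream API above): query and synchronize a non-default stream
+
+ // acquire a new CUDA stream on device 0 and make it the current stream
+ at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(false, 0);
+ at::cuda::setCurrentCUDAStream(myStream);
+
+ // getCurrentCUDAStream() returns the stream that is current on the current device (device 0),
+ // which is now `myStream`
+ TORCH_CHECK(at::cuda::getCurrentCUDAStream() == myStream);
+
+ // launch some work on `myStream`
+ torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA));
+ tensor0.sum();
+
+ // block the calling CPU thread until all work queued on `myStream` so far has finished
+ myStream.synchronize();
+
+ // restore the default CUDA stream as the current stream on device 0
+ at::cuda::setCurrentCUDAStream(at::cuda::getDefaultCUDAStream());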
From 5072728d8810e9f9ffdb5bb6d89861ec91767943 Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Tue, 6 Oct 2020 15:02:21 -0700 Subject: [PATCH 08/69] Fix stride printing/parsing formatting (#45156) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45156 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D24078695 Pulled By: ansley fbshipit-source-id: dab993277d43b31105c38d12098c37653747b42a --- aten/src/ATen/core/type.cpp | 15 +++- docs/source/onnx.rst | 44 +++++----- test/cpp/jit/test_constant_pooling.cpp | 4 +- test/cpp/jit/test_gpu.cpp | 8 +- test/cpp/jit/test_interpreter.cpp | 2 +- test/cpp/jit/test_irparser.cpp | 8 +- test/cpp/jit/test_misc.cpp | 4 +- test/cpp/tensorexpr/test_kernel.cpp | 50 ++++++------ test/cpp/tensorexpr/test_te_fuser_pass.cpp | 80 +++++++++---------- test/test_jit.py | 16 ++-- .../csrc/jit/frontend/schema_type_parser.cpp | 21 ++--- .../jit/passes/onnx/preprocess_for_onnx.cpp | 23 +++--- torch/onnx/__init__.py | 12 +-- torch/onnx/utils.py | 8 +- 14 files changed, 153 insertions(+), 142 deletions(-) diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 13e82d434647..4b0df2afc1d3 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { out << "Tensor"; } if (auto ndim = value->sizes().size()) { - bool has_valid_strides_info = + bool has_valid_strides_info = *ndim > 0 && value->strides().isComplete() && value->strides().size() == ndim; out << "("; @@ -41,10 +41,17 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } else { out << "*"; } - if (has_valid_strides_info && - type_verbosity() >= TypeVerbosity::TypeAndStride) { - out << ":" << *value->strides()[i]; + } + if (has_valid_strides_info && + type_verbosity() >= TypeVerbosity::TypeAndStride) { + out << ", strides=["; + for (size_t i = 0; i < *ndim; ++i) { + if (i > 0) { + out << ", "; + } + out << *value->strides()[i]; } + out << "]"; } if (type_verbosity() >= TypeVerbosity::Full) { if (value->requiresGrad()) { diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 3c07486b0e89..655016cb19df 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -649,16 +649,16 @@ This mode is used to export all operators as regular ONNX operators. 
This is the Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1)): - %3 : Float(2:12, 3:4, 4:1) = aten:exp(%0) - %4 : Float(2:12, 3:4, 4:1) = aten:div(%0, %3) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten:exp(%0) + %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten:div(%0, %3) return (%4) Is exported as: - graph(%0 : Float(2:12, 3:4, 4:1)): - %1 : Float(2:12, 3:4, 4:1) = onnx:Exp(%0) - %2 : Float(2:12, 3:4, 4:1) = onnx:Div(%0, %1) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %1 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx:Exp(%0) + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx:Div(%0, %1) return (%2) @@ -668,16 +668,16 @@ This mode is used to export all operators as ATen ops, and avoid conversion to O Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1)): - %3 : Float(2:12, 3:4, 4:1) = aten::exp(%0) - %4 : Float(2:12, 3:4, 4:1) = aten::div(%0, %3) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0) + %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3) return (%4) Is exported as: - graph(%0 : Float(2:12, 3:4, 4:1)): - %1 : Float(2:12, 3:4, 4:1) = aten::ATen[operator="exp"](%0) - %2 : Float(2:12, 3:4, 4:1) = aten::ATen[operator="div"](%0, %1) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %1 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="exp"](%0) + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="div"](%0, %1) return (%2) ONNX_ATEN_FALLBACK @@ -707,7 +707,7 @@ To export a raw ir. :: Example torch ir graph: - graph(%x.1 : Float(1:1)): + graph(%x.1 : Float(1, strides=[1])): %1 : Tensor = aten::exp(%x.1) %2 : Tensor = aten::div(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -715,7 +715,7 @@ To export a raw ir. 
:: is exported as: - graph(%x.1 : Float(1:1)): + graph(%x.1 : Float(1, strides=[1])): %1 : Tensor = aten::exp(%x.1) %2 : Tensor = aten::div(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -729,18 +729,18 @@ enables users to register and implement the operator as part of their runtime ba Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1), - %1 : Float(2:12, 3:4, 4:1)): - %6 : Float(2:12, 3:4, 4:1) = foo_namespace::bar(%0, %1) # custom op - %7 : Float(2:12, 3:4, 4:1) = aten::div(%6, %0) # registered op + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]), + %1 : Float(2, 3, 4, strides=[12, 4, 1])): + %6 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op + %7 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%6, %0) # registered op return (%7)) is exported as: - graph(%0 : Float(2:12, 3:4, 4:1), - %1 : Float(2:12, 3:4, 4:1)): - %2 : Float(2:12, 3:4, 4:1) = foo_namespace::bar(%0, %1) # custom op - %3 : Float(2:12, 3:4, 4:1) = onnx::Div(%2, %0) # registered op + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]), + %1 : Float(2, 3, 4, strides=[12, 4, 1])): + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%2, %0) # registered op return (%3 diff --git a/test/cpp/jit/test_constant_pooling.cpp b/test/cpp/jit/test_constant_pooling.cpp index c8cb58e1886a..8479c96742b0 100644 --- a/test/cpp/jit/test_constant_pooling.cpp +++ b/test/cpp/jit/test_constant_pooling.cpp @@ -79,11 +79,11 @@ graph(): ConstantPooling(graph); testing::FileCheck() .check_count( - "Float(2:1, requires_grad=0, device=cpu) = prim::Constant", + "Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant", 1, /*exactly*/ true) ->check_count( - "Long(2:1, requires_grad=0, device=cpu) = prim::Constant", + "Long(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant", 1, /*exactly*/ true) ->run(*graph); diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 38008d417256..f41149c37fe9 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1085,10 +1085,10 @@ TEST(NVFuserTest, FusionDependency_CUDA) { TEST(NVFuserTest, FusionParser_CUDA) { auto g = std::make_shared(); const auto graph0_string = R"IR( - graph(%0 : Float(2:1), - %1 : Float(2:1)): - %c0 : Float(2:1) = aten::mul(%0, %1) - %d0 : Float(2:1) = aten::mul(%c0, %0) + graph(%0 : Float(2, strides=[1]), + %1 : Float(2, strides=[1])): + %c0 : Float(2, strides=[1]) = aten::mul(%0, %1) + %d0 : Float(2, strides=[1]) = aten::mul(%c0, %0) return (%d0))IR"; torch::jit::parseIR(graph0_string, g.get()); diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp index da4607d7f047..3b90bd107fd4 100644 --- a/test/cpp/jit/test_interpreter.cpp +++ b/test/cpp/jit/test_interpreter.cpp @@ -19,7 +19,7 @@ class TypeCheckTest : public ::testing::Test { R"IR( graph(%a.1 : Tensor, %b.1 : Tensor): - %t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1) + %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck(%a.1, %b.1) return (%t0, %t1, %type_matched) )IR", &*graph, diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp index 57f21f5bf5f9..6db8ba26639d 100644 --- a/test/cpp/jit/test_irparser.cpp +++ b/test/cpp/jit/test_irparser.cpp @@ -269,7 +269,7 @@ TEST(IRParserTest, Strides) { parseIR( R"IR( graph(%a : Float(4, 5), - %b : Float(4:5, 
5:1), + %b : Float(4, 5, strides=[5, 1]), %c : Double(*, *)): return (%a) )IR", @@ -303,7 +303,7 @@ TEST(IRParserTest, MalformedStrides) { bool error_thrown = false; EXPECT_ANY_THROW(parseIR( R"IR( -graph(%a : Float(4:5, 5)): +graph(%a : Float(4, strides=[5], 5)): return (%a) )IR", &*graph, @@ -314,7 +314,7 @@ TEST(IRParserTest, TensorShapes) { checkRoundtrip( R"IR( graph(%a : Float(4, 5), - %b : Float(4:5, 5:1), + %b : Float(4, 5, strides=[5, 1]), %c : Double(*, *)): return (%a) )IR"); @@ -327,7 +327,7 @@ graph(%a : Float(*, *, device=cpu), %b : Float(*, *, requires_grad=1), %c : Long(5, 10, requires_grad=1, device=cpu), %d : Float(5, requires_grad=0, device=cuda:2), - %e : Long(4:6, 3:2, 2:1, requires_grad=0, device=cuda:1), + %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1), %f : Float(), %g : Float(device=cpu), %h : Float(requires_grad=1), diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index d205ae3d58db..ca4fb2e7620d 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1756,8 +1756,8 @@ TEST(ProfilerTest, Basic) { is.run(stack); // profiled types are stored as attributes and show up in the dump, e.g. - // Tensor = prim::profile[profiled_type=Double(4:256, 256:1, requires_grad=0, - // device=cpu) + // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1], + // requires_grad=0, device=cpu) testing::FileCheck() .check("Tensor = prim::profile[profiled_type") ->check_same("256") diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index d80710fa732b..2b2cec70358d 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -21,10 +21,10 @@ void testKernel_1() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:3,3:1, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -60,10 +60,10 @@ void testKernel_2() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:1,3:5, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -100,10 +100,10 @@ void testKernel_3() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:12,3:2, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -143,8 +143,8 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3, 3:1, device=cpu), - %1 : Float(5:12, 3:2, device=cpu)): + 
graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): %2 : Tensor = aten::mul(%0, %1) %3 : Tensor = aten::mul(%0, %2) return (%3))IR"; @@ -182,8 +182,8 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(8:8, 8:1, device=cpu), - %1 : Float(8:8, 8:1, device=cpu)): + graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), + %1 : Float(8, 8, strides=[8, 1], device=cpu)): %2 : Tensor = aten::mul(%0, %1) %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) %r : Tensor = aten::mul(%3, %4) @@ -223,9 +223,9 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(4:2, 2:1, device=cpu), - %b : Float(4:6, 3:2, 2:1, device=cpu), - %c : Float(3:4, 2:2, 2:1, device=cpu)): + graph(%a : Float(4, 2, strides=[2, 1], device=cpu), + %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), + %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): %one : int = prim::Constant[value=1]() %minus_one : int = prim::Constant[value=-1]() %three : int = prim::Constant[value=3]() @@ -286,9 +286,9 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(5:6, 3:2, 2:1, device=cpu), - %b : Float(5:14, 7:2, 2:1, device=cpu), - %c : Float(5:18, 9:2, 2:1, device=cpu)): + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): %dim : int = prim::Constant[value=1]() %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] @@ -363,7 +363,7 @@ at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { void testKernelSumAllAxes() { // Test lowering of sum on all axes. const auto graph_template = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu)): + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): %1 : ${dtype} %2 : Tensor = aten::sum(%0, %1) return (%2))IR"; @@ -410,7 +410,7 @@ void testKernelSumAllAxes() { void testKernelSumOneAxis() { // Test lowering of sum on one axis. const auto graph_template = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu)): + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): %1 : int[] = prim::Constant[value=[${dim}]]() %2 : bool = prim::Constant[value=${keepdim}]() %3 : ${dtype} @@ -466,7 +466,7 @@ void testKernelSumOneAxis() { void testKernelSumMultipleAxes() { // Test lowering of sum on multiple axes. 
const auto graph_template = R"IR( - graph(%0 : Float(2:18,3:6,2:3,3:1, device=cpu)): + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): %1 : int = prim::Constant[value=${dim1}]() %2 : int = prim::Constant[value=${dim2}]() %3 : int[] = prim::ListConstruct(%1, %2) diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp index 826cf7209346..0ad4df33019c 100644 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -28,14 +28,14 @@ void testFuserPass_1() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(128:1, device=cpu), - %1 : Float(128:1, device=cpu)): + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() - %2.1 : Float(128:1, device=cpu) = aten::mul(%0, %1) - %2 : Float(128:1, device=cpu) = aten::mul(%2.1, %1) - %3 : Float(128:1, device=cpu) = aten::add_(%2, %1, %12) - %4 : Float(128:1, device=cpu) = aten::mul(%2, %1) - %5 : Float(128:1, device=cpu) = aten::add(%2, %4, %12) + %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) + %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) + %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) + %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) return (%5))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -55,13 +55,13 @@ void testFuserPass_2() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(128:1, device=cpu), - %1 : Float(128:1, device=cpu)): + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() - %a : Float(128:1, device=cpu) = aten::mul(%0, %1) - %b : Float(128:1, device=cpu) = aten::add(%0, %1, %12) - %c : Float(128:1, device=cpu) = aten::add_(%b, %1, %12) - %d : Float(128:1, device=cpu) = aten::mul(%c, %a) + %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) + %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) + %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) return (%d))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -80,9 +80,9 @@ void testFuserPass_3() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(128:1, device=cpu), - %y : Float(128:1, device=cpu)): - %r : Float(128:1, device=cpu) = aten::mul(%x, %y) + graph(%x : Float(128, strides=[1], device=cpu), + %y : Float(128, strides=[1], device=cpu)): + %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) return (%r))IR"; { auto g = std::make_shared(); @@ -129,9 +129,9 @@ void testFuserPass_UnfusibleDevice() { WithCPUFuser cf(false); KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(10:1, device=cpu)): - %a : Float(10:1, device=cpu) = aten::mul(%x, %y) + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(10, strides=[1], device=cpu)): + %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) return (%a))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -167,9 +167,9 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, 
device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(30:1, device=cpu)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) %cat : Tensor = aten::cat(%xyz_list, %dim) @@ -187,9 +187,9 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cuda:0), - %z : Float(30:1, device=cpu)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0), + %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) %cat : Tensor = aten::cat(%xyz_list, %dim) @@ -208,9 +208,9 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(10:1, device=cuda:0)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y) %xy_cat : Tensor = aten::cat(%xy_list, %dim) @@ -230,9 +230,9 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(10:1, device=cuda:0)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): %z2 : Tensor = aten::mul(%z, %z) %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) @@ -252,8 +252,8 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cuda:0)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0)): %r : Tensor = aten::mul(%x, %y) return (%r))IR"; auto g = std::make_shared(); @@ -269,9 +269,9 @@ void testFuserPass_Multidevice() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cuda:0), - %y : Float(20:1, device=cuda:1), - %z : Float(20:1, device=cpu)): + graph(%x : Float(10, strides=[1], device=cuda:0), + %y : Float(20, strides=[1], device=cuda:1), + %z : Float(20, strides=[1], device=cpu)): %x2 : Tensor = aten::mul(%x, %x) %y2 : Tensor = aten::mul(%y, %y) %z2 : Tensor = aten::mul(%z, %z) @@ -292,10 +292,10 @@ void testFuserPass_MergeGroups() { WithCPUFuser cf; KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(128:1, device=cpu), - %b : Float(128:1, device=cpu)): - %x : Float(128:1, device=cpu) = aten::mul(%a, %a) - %y : Float(128:1, device=cpu) = aten::mul(%b, %b) + graph(%a : Float(128, strides=[1], device=cpu), + %b : Float(128, strides=[1], device=cpu)): + %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) + %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) return (%x, %y))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); diff --git a/test/test_jit.py b/test/test_jit.py index d093a4b8826e..5baa240e30b8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1299,7 +1299,7 @@ def broadcast(a, b): graph = 
torch.jit.script(broadcast).graph torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False) - FileCheck().check("Double(4:120, 3:40, 8:5, 5:1, device=cpu)").run(str(graph)) + FileCheck().check("Double(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph)) def test_shape_analysis_unsqueeze_in_loop(self): input_str = """graph(%x.1 : Tensor): @@ -2804,8 +2804,8 @@ def test_not_const(x): test_not_const(torch.rand([2, 2])) graph_str = torch.jit.last_executed_optimized_graph() - FileCheck().check("profiled_type=Double(*:2, 2:1, requires_grad=0, device=cpu").run(graph_str) - FileCheck().check_not("profiled_type=Double(1:2, 2:1, requires_grad=0, device=cpu").run(graph_str) + FileCheck().check("profiled_type=Double(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) + FileCheck().check_not("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) def test_nested_bailouts(self): @@ -9775,7 +9775,7 @@ def forward(self): cm = ScriptMod(Mod()) # specialized tensor in graph - FileCheck().check("Double(1:3, 3:1, requires_grad=0, device=cpu)").run(cm.forward.graph) + FileCheck().check("Double(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph) buffer = io.BytesIO() torch.jit.save(cm, buffer) buffer.seek(0) @@ -10334,7 +10334,7 @@ def foo(x, y): a = torch.zeros(2, 2) b = torch.zeros(4, dtype=torch.long) torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False) - FileCheck().check("Double(2:4, 4:1, requires_grad=0, device=cpu)").run(str(foo.graph)) + FileCheck().check("Double(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph)) def test_shape_analysis_loop(self): def foo(a, b, x): @@ -10622,8 +10622,8 @@ def test_rand(): out = fn() graph_str = torch.jit.last_executed_optimized_graph() self.assertEqual(out.dtype, torch.double) - FileCheck().check("Double(3:4, 4:1, requires_grad=0, device=cpu)") \ - .check_not("Float(3:4, 4:1, requires_grad=0, device=cpu)").run(graph_str) + FileCheck().check("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \ + .check_not("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str) # fn = self.checkScript(test_rand, ()) # out = fn() @@ -10640,7 +10640,7 @@ def randint(): out = randint() graph_str = torch.jit.last_executed_optimized_graph() self.assertEqual(out.dtype, torch.double) - FileCheck().check("profiled_type=Double(1:2, 2:1, requires_grad=0, device=cpu)").run(graph_str) + FileCheck().check("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) def test_erase_number_types(self): diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index a089abf7fb2c..6dd970378bf0 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -190,7 +190,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { // unknown sizes, a mix of ranks with known and unknown sizes, or ranks with // known sizes and strides. The type might also have requires_grad and/or // device option. 
Examples of types we're handling here: - // Long(10:48,8:6,6:1, requires_grad=0, device=cuda:1) + // Long(10, 8, 6, strides=[48, 6, 1], requires_grad=0, device=cuda:1) // Float(10, *, 20, device=cuda:1) // Float(requires_grad=1) std::vector> dims; @@ -220,6 +220,17 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { } return; } + if (field == "strides") { + seen_strides = true; + L.expect('='); + parseList('[', ',', ']', [&] { + const std::string& num = L.expect(TK_NUMBER).text(); + std::string::size_type num_len; + size_t stride = c10::stoi(num, &num_len); + strides.push_back(stride); + }); + return; + } throw ErrorReport(L.cur()) << "Unexpected specifier '" << field << "'"; } if (device.has_value() || requires_grad.has_value()) { @@ -241,14 +252,6 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { std::string::size_type num_len; size_t dim = c10::stoi(num, &num_len); dims.emplace_back(dim); - if (seen_strides || L.cur().kind == ':') { - L.expect(':'); - seen_strides = true; - const std::string& num = L.expect(TK_NUMBER).text(); - std::string::size_type num_len; - size_t stride = c10::stoi(num, &num_len); - strides.push_back(stride); - } }); if (seen_strides) { at::IntArrayRef strides_ref(strides); diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 3c62b2877fa5..1f583fc3fd5d 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -37,22 +37,23 @@ at::optional FindFusibleListUnpack(Node* n) { // split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] // split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] // -// graph(%input : Float(5:12, 4:3, 3:1)): +// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])): // %13 : int[] = prim::Constant[value=[2, 1, 2]]() // %7 : int = prim::Constant[value=0]() // %8 : Tensor[] = aten::split_with_sizes(%input, %13, %7) -// %9 : Float(2:12, 4:3, 3:1), %10 : Float(1:12, 4:3, 3:1), %11 : Float(2:12, -// 4:3, 3:1) = prim::ListUnpack(%8) return (%9, %10, %11) +// %9 : Float(2, 4, 3, strides=[12, 3, 1]), %10 : Float(1, 4, 3, strides=[12, +// 3, 1]), %11 : Float(2, 4, 3, strides=[12, 3, 1]) = prim::ListUnpack(%8) +// return (%9, %10, %11) // // After fusion -// graph(%input : Float(5:12, 4:3, 3:1)): +// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])): // %13 : int[] = prim::Constant[value=[2, 1, 2]]() // %7 : int = prim::Constant[value=0]() // %8 : int = prim::Constant[value=3]() # Adding addtional input of value 3 // representing the number of outputs. 
-// %14 : Float(2:12, 4:3, 3:1), %15 : Float(1:12, 4:3, 3:1), %16 : Float(2:12, -// 4:3, 3:1) = aten::split_with_sizes(%input, %13, %7, %8) -// return (%14, %15, %16) +// %14 : Float(2, 4, 3, strides=[12, 3, 1]), %15 : Float(1, 4, 3, strides=[12, +// 3, 1]), %16 : Float(2, 4, 3, strides=[12, 3, 1] = +// aten::split_with_sizes(%input, %13, %7, %8) return (%14, %15, %16) void FuseWithListUnpack(Node* n) { auto found_listUnpack = FindFusibleListUnpack(n); if (!found_listUnpack) { @@ -108,8 +109,8 @@ static void FuseWithListUnpack(Block* b) { // when inputs to the add node are two int lists // // before the pass: -// graph(%x.1 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu), -// %y.1 : Float(1:6, 2:3, 3:1, requires_grad=0, device=cpu)): +// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu), +// %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)): // %2 : None = prim::Constant() // %3 : int[] = aten::size(%x.1) // %l1.1 : int[] = aten::list(%3 @@ -120,8 +121,8 @@ static void FuseWithListUnpack(Block* b) { // return (%8) // // after the pass: -// graph(%x.1 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu), -// %y.1 : Float(1:6, 2:3, 3:1, requires_grad=0, device=cpu)): +// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu), +// %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)): // %2 : None = prim::Constant() // %3 : int[] = aten::size(%x.1) // %l1.1 : int[] = aten::list(%3) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 4ec89e4c9b0b..255c15b9da4a 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -102,12 +102,12 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM exporter falls back on this op. OperatorExportTypes.RAW: Export raw ir. OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported - in ONNX, fall through and export the operator as is, as a custom + in ONNX, fall through and export the operator as is, as a custom ONNX op. Using this mode, the op can be exported and implemented by the user for their runtime backend. Example graph:: - graph(%x.1 : Long(1:1)):: + graph(%x.1 : Long(1, strides=[1])):: %1 : None = prim::Constant() %2 : Tensor = aten::sum(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -115,7 +115,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM is exported as:: - graph(%x.1 : Long(1:1)):: + graph(%x.1 : Long(1, strides=[1])):: %1 : Tensor = onnx::ReduceSum[keepdims=0](%x.1) %y.1 : Long() = prim::ListConstruct(%1) return (%y.1) @@ -212,13 +212,13 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM external_data_format (bool, default False): If True, then the model is exported in ONNX external data format, in which case some of the model parameters are stored in external binary files and not in the ONNX model file itself. See link for format - details: + details: https://github.com/onnx/onnx/blob/8b3f7e2e7a0f2aba0e629e23d89f07c7fc0e6a5e/onnx/onnx.proto#L423 Also, in this case, argument 'f' must be a string specifying the location of the model. - The external binary files will be stored in the same location specified by the model + The external binary files will be stored in the same location specified by the model location 'f'. If False, then the model is stored in regular format, i.e. model and parameters are all in one file. This argument is ignored for all export types other - than ONNX. + than ONNX. 
""" from torch.onnx import utils diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 98dc79d6546c..16e7af721ebf 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -533,18 +533,18 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini the user for their runtime backend. Example graph:: - graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu)): %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() %4 : None = prim::Constant() - %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + %5 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 return (%5) is exported as:: - graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu)): %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() %4 : None = prim::Constant() - %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + %5 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 return (%5) In the above example, aten::cumsum in not implemented in opset 9, hence exporter falls From 14997f2125aa9d49edda94043c1ea9b8f2a692ef Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 6 Oct 2020 15:31:11 -0700 Subject: [PATCH 09/69] [quant][graphmode][fx] Add warning for unsupported case (#45714) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45714 Hit the problem when writing a test like following: ``` class M(...): def forward(self, x): x = x.some_op() return x ``` we need to know the scope of `x` to figure out the qconfig for `x` Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D24069959 fbshipit-source-id: 95ac8963c802ebce5d0e54d55f5ebb42085ca8a6 --- torch/quantization/fx/quantize.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 14bd2c8eee1e..74dee6ea3cf3 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -47,6 +47,7 @@ ) from collections import OrderedDict +import warnings import copy import re @@ -297,7 +298,13 @@ def get_qconfig(module_name): elif node.op == 'call_method': self_obj = node.args[0] # qconfig for call_method should be the same as the `self` object for the call - self.qconfig_map[node.name] = self.qconfig_map[self_obj.name] + if self_obj.name in self.qconfig_map: + qconfig = self.qconfig_map[self_obj.name] + else: + # need scope info for each node to support this + warnings.warn("Scope info is not yet supported, taking default qconfig for value {}".format(node.name)) + qconfig = get_qconfig('') + self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': module_qconfig = get_qconfig(node.target) # regex is not supported eager mode propagate_qconfig_, we'll need to From 5ff31620b75226f22b4ffa5f9f90132c98702134 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 6 Oct 2020 16:51:59 -0700 Subject: [PATCH 10/69] [te] Add a 2D convolution example test (#45514) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45514 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D24142405 Pulled By: bertmaher fbshipit-source-id: 8f064d0638b48f55a732c08938b9fcf1ba3f0415 --- 
test/cpp/tensorexpr/test_conv.cpp | 86 +++++++++++++++++++++++++++++++ test/cpp/tensorexpr/tests.h | 3 +- 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 test/cpp/tensorexpr/test_conv.cpp diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp new file mode 100644 index 000000000000..d4d1c46e2044 --- /dev/null +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace te = torch::jit::tensorexpr; +namespace F = torch::nn::functional; + +void testConv2D() { + te::KernelScope kernel_scope; + + // Input dimensions. + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 11; + constexpr int W = 11; + + // Filter dimensions. + constexpr int K = 8; + constexpr int R = 3; + constexpr int S = 3; + + // Output dims. + constexpr int OH = H - R + 1; + constexpr int OW = W - S + 1; + + // Compute reference result. + at::Tensor input = torch::randn({N, C, H, W}); + at::Tensor filter = torch::randn({K, C, R, S}); + at::Tensor ref = F::conv2d(input, filter); + + // Double check the output size is as expected. + ASSERT_EQ(ref.size(0), N); + ASSERT_EQ(ref.size(1), K); + ASSERT_EQ(ref.size(2), OH); + ASSERT_EQ(ref.size(3), OW); + + te::Placeholder inputB(te::BufHandle("input", {N, C, H, W}, te::kFloat)); + te::Placeholder filterB(te::BufHandle("filter", {K, C, R, S}, te::kFloat)); + + te::Tensor* conv = te::Reduce( + "conv", + {{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}}, + te::Sum(), + // FIXME: We have to use a `std::vector` parameter here and then unpack + // it, because we don't have an overload allowing for an arbitrary number + // of ExprHandle/VarHandle parameters. + [&](const std::vector& v) { + auto const& n = v[0]; + auto const& k = v[1]; + auto const& oh = v[2]; + auto const& ow = v[3]; + auto const& c = v[4]; + auto const& r = v[5]; + auto const& s = v[6]; + // FIXME: We have to use `call` and construct a `std::vector` here + // because the `operator()` overload is only specialized for a small + // number of arguments. + return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); + }, + // FIXME: If you forget one of the reduction dims, you get a segfault. + // Could that be caught by a verifier? + {{C, "c"}, {R, "r"}, {S, "s"}}); + + // FIXME: It'd be nice to have a single header that pulls in things like + // LoopNest, IRSimplifier, etc. 
+ te::LoopNest loop({conv}); + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + + at::Tensor result = at::empty_like(ref); + te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); + cg.call({input.data_ptr(), + filter.data_ptr(), + result.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index dc21373f241f..4337c14fe3eb 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -298,7 +298,8 @@ namespace jit { _(FuserPass_UnknownShapesIgnored) \ _(FuserPass_Multidevice) \ _(FuserPass_MergeGroups) \ - _(TrainBasic) + _(TrainBasic) \ + _(Conv2D) #define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \ _(LLVMByteImmTest) \ From 50f89578ddc8c6806e22e818e9b895e588133b95 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 6 Oct 2020 16:51:59 -0700 Subject: [PATCH 11/69] [te] Add a benchmark harness (#45875) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45875 Adds a googlebenchmark harness for perf testing programs generated by tensorexpr, sans any pytorch wrappings (for python-level benchmarks of tensorexpr, see benchmarks/tensorexpr). Currently there's a harness for gemm that sets up the problem using torch (and also measures the perf of a torch::mm to give a baseline). Right now there's just an unoptimized implementation that is expected to be not very fast. More optimized versions are coming. Sample output from my dev box: ``` Run on (48 X 2501 MHz CPU s) CPU Caches: L1 Data 32K (x24) L1 Instruction 32K (x24) L2 Unified 256K (x24) L3 Unified 30720K (x2) -------------------------------------------------------------------------------------------- Benchmark Time CPU Iterations UserCounters... 
-------------------------------------------------------------------------------------------- Gemm/Torch/128/128/128 73405 ns 73403 ns 8614 GFLOPS=57.1411G/s Gemm/TensorExprNoopt/128/128/128 3073003 ns 3072808 ns 229 GFLOPS=1.36497G/s ``` Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D24142403 Pulled By: bertmaher fbshipit-source-id: 3354aaa56868a43a553acd1ad9a192f28d8e3597 --- benchmarks/cpp/tensorexpr/CMakeLists.txt | 2 + benchmarks/cpp/tensorexpr/tensorexpr.cpp | 67 ++++++++++++++++++++++++ caffe2/CMakeLists.txt | 4 ++ 3 files changed, 73 insertions(+) create mode 100644 benchmarks/cpp/tensorexpr/CMakeLists.txt create mode 100644 benchmarks/cpp/tensorexpr/tensorexpr.cpp diff --git a/benchmarks/cpp/tensorexpr/CMakeLists.txt b/benchmarks/cpp/tensorexpr/CMakeLists.txt new file mode 100644 index 000000000000..c047423c1f9c --- /dev/null +++ b/benchmarks/cpp/tensorexpr/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(tensorexpr_bench tensorexpr.cpp) +target_link_libraries(tensorexpr_bench PRIVATE torch_library benchmark) diff --git a/benchmarks/cpp/tensorexpr/tensorexpr.cpp b/benchmarks/cpp/tensorexpr/tensorexpr.cpp new file mode 100644 index 000000000000..a39ba48f1b79 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/tensorexpr.cpp @@ -0,0 +1,67 @@ +#include +#include "torch/torch.h" +#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" +#include "torch/csrc/jit/tensorexpr/loopnest.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace te = torch::jit::tensorexpr; + +class Gemm : public benchmark::Fixture { + public: + void SetUp(const benchmark::State& state) { + M = state.range(0); + N = state.range(1); + K = state.range(2); + A = torch::randn({M, K}); + B = torch::randn({K, N}); + C = torch::mm(A, B); + } + + void TearDown(benchmark::State& state) { + state.counters["GFLOPS"] = + benchmark::Counter(uint64_t(state.iterations()) * 2 * M * N * K, + benchmark::Counter::kIsRate); + } + + int M; + int N; + int K; + at::Tensor A; + at::Tensor B; + at::Tensor C; +}; + +BENCHMARK_DEFINE_F(Gemm, Torch)(benchmark::State& state) { + for (auto _ : state) { + torch::mm_out(C, A, B); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, const te::ExprHandle& n, const te::ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {{K, "K"}}); + te::LoopNest loop({CT}); + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_REGISTER_F(Gemm, Torch)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprNoopt)->Args({128, 128, 128}); + +BENCHMARK_MAIN(); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 318e46a44f54..fe5240118b2f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1261,6 +1261,10 @@ if(BUILD_STATIC_RUNTIME_BENCHMARK) target_link_libraries(static_runtime_test torch_library gtest_main) endif() +if(BUILD_TENSOREXPR_BENCHMARK) + add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/tensorexpr ${CMAKE_BINARY_DIR}/tensorexpr_bench) +endif() + if(BUILD_MOBILE_BENCHMARK) foreach(benchmark_src ${ATen_MOBILE_BENCHMARK_SRCS}) 
get_filename_component(benchmark_name ${benchmark_src} NAME_WE) From f2e569461b7eed09ba9b84d3a3574907d097f9a5 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 6 Oct 2020 16:51:59 -0700 Subject: [PATCH 12/69] [te] Tiled (m=32 x n=32) gemm benchmark (#45905) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45905 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D24142402 Pulled By: bertmaher fbshipit-source-id: b39e18b6985ee1c1f654fba4498ed91ff14d8d5f --- benchmarks/cpp/tensorexpr/tensorexpr.cpp | 79 +++++++++++++++++++++--- 1 file changed, 72 insertions(+), 7 deletions(-) diff --git a/benchmarks/cpp/tensorexpr/tensorexpr.cpp b/benchmarks/cpp/tensorexpr/tensorexpr.cpp index a39ba48f1b79..b57e845de7eb 100644 --- a/benchmarks/cpp/tensorexpr/tensorexpr.cpp +++ b/benchmarks/cpp/tensorexpr/tensorexpr.cpp @@ -1,8 +1,8 @@ #include -#include "torch/torch.h" #include "torch/csrc/jit/tensorexpr/ir_simplifier.h" #include "torch/csrc/jit/tensorexpr/loopnest.h" #include "torch/csrc/jit/tensorexpr/tensor.h" +#include "torch/torch.h" namespace te = torch::jit::tensorexpr; @@ -18,9 +18,9 @@ class Gemm : public benchmark::Fixture { } void TearDown(benchmark::State& state) { - state.counters["GFLOPS"] = - benchmark::Counter(uint64_t(state.iterations()) * 2 * M * N * K, - benchmark::Counter::kIsRate); + state.counters["GFLOPS"] = benchmark::Counter( + uint64_t(state.iterations()) * 2 * M * N * K, + benchmark::Counter::kIsRate); } int M; @@ -46,9 +46,9 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), - [&](const te::ExprHandle& m, const te::ExprHandle& n, const te::ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, {{K, "K"}}); te::LoopNest loop({CT}); loop.prepareForCodegen(); @@ -61,7 +61,72 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { } } +BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* m = loops[0]; + te::For* mo; + te::For* mi; + loop.splitWithMask(m, 32, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* n = loops[2]; + te::For* no; + te::For* ni; + loop.splitWithMask(n, 32, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[1]; + te::For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* ni = loops[3]; + te::For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[2]; + te::For* k = loops[3]; + loop.reorderAxis(mi, k); + } + + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + 
cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + BENCHMARK_REGISTER_F(Gemm, Torch)->Args({128, 128, 128}); BENCHMARK_REGISTER_F(Gemm, TensorExprNoopt)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprTile32x32)->Args({128, 128, 128}); BENCHMARK_MAIN(); From 624084e6d61b048573c7071d4d7453465479f1a2 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 6 Oct 2020 16:51:59 -0700 Subject: [PATCH 13/69] [te][llvm] Enable fused multiply-add (fma) in code generation (#45906) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45906 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D24142404 Pulled By: bertmaher fbshipit-source-id: a8db2e66c1e65bbb255886e165a1773723cbcd20 --- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 2d20bd1b47d0..3873cdd0ebf0 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -190,6 +190,7 @@ static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() { JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default); JTMB.setCPU(llvm::sys::getHostCPUName()); JTMB.addFeatures(SubtargetFeatures.getFeatures()); + JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast; return JTMB; #endif From f5e70a750464fc3b022cecb57f1bfe19242a3884 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Tue, 6 Oct 2020 17:29:57 -0700 Subject: [PATCH 14/69] fix test flakiness caused by sys.getrefcount(None) (#45876) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45876 sys.getrefcount() can be flaky before/after scope() call Test Plan: buck test mode/opt-asan //caffe2/test:others -- 'test_none_names_refcount \(test_namedtensor\.TestNamedTensor\)' --run-disabled Reviewed By: malfet Differential Revision: D24123724 fbshipit-source-id: 4af0b150222cfb92dd0776a42fcab44d896a772a --- test/test_namedtensor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 5358d2bbab10..72e8bd6dd108 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -146,16 +146,18 @@ def _test_factory(self, factory, device): names65 = ['A' * i for i in range(1, 66)] x = factory([1] * 65, names=names64, device=device) - def test_none_names_refcount(self): + def test_none_names_refcount(self, N=10): def scope(): unnamed = torch.empty(2, 3) unnamed.names # materialize [None, None] prev_none_refcnt = sys.getrefcount(None) - scope() - self.assertEqual(sys.getrefcount(None), prev_none_refcnt, - msg='Using tensor.names should not change ' - 'the refcount of Py_None') + # Ran it N times to reduce flakiness + [scope() for i in range(N)] + after_none_refcnt = sys.getrefcount(None) + self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2, + msg='Using tensor.names should not change ' + 'the refcount of Py_None') def test_has_names(self): unnamed = torch.empty(2, 3) From 49af42114353cb5386892e250d63779f0803eae3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 6 Oct 2020 17:49:09 -0700 Subject: [PATCH 15/69] Embed callgrind headers (#45914) Summary: Because access to https://sourceware.org/git/valgrind.git can be really slow especially in some regions Pull Request resolved: https://github.com/pytorch/pytorch/pull/45914 Reviewed By: seemethere Differential Revision: D24144420 Pulled By: malfet fbshipit-source-id: 
a454c8c3182c570ec344bf6468bb5e55d8b8da79 --- .gitmodules | 8 +- third_party/valgrind | 1 - third_party/valgrind-headers/README.md | 5 + third_party/valgrind-headers/callgrind.h | 129 + third_party/valgrind-headers/valgrind.h | 7157 ++++++++++++++++++++++ torch/CMakeLists.txt | 3 +- 6 files changed, 7294 insertions(+), 9 deletions(-) delete mode 160000 third_party/valgrind create mode 100644 third_party/valgrind-headers/README.md create mode 100644 third_party/valgrind-headers/callgrind.h create mode 100644 third_party/valgrind-headers/valgrind.h diff --git a/.gitmodules b/.gitmodules index d7a11cc22996..c7de63e5af63 100644 --- a/.gitmodules +++ b/.gitmodules @@ -124,13 +124,9 @@ url = https://github.com/google/XNNPACK.git [submodule "third_party/fmt"] ignore = dirty - path = third_party/fmt - url = https://github.com/fmtlib/fmt.git + path = third_party/fmt + url = https://github.com/fmtlib/fmt.git [submodule "third_party/tensorpipe"] ignore = dirty path = third_party/tensorpipe url = https://github.com/pytorch/tensorpipe.git -[submodule "third_party/valgrind"] - ignore = dirty - path = third_party/valgrind - url = https://sourceware.org/git/valgrind.git diff --git a/third_party/valgrind b/third_party/valgrind deleted file mode 160000 index 2593ccd82c18..000000000000 --- a/third_party/valgrind +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2593ccd82c189bf40b60a3a4934c5d0bbdb75427 diff --git a/third_party/valgrind-headers/README.md b/third_party/valgrind-headers/README.md new file mode 100644 index 000000000000..98173f37ad6e --- /dev/null +++ b/third_party/valgrind-headers/README.md @@ -0,0 +1,5 @@ +This folder contains 2 Valgrind headers, downloaded from +https://sourceware.org/git/?p=valgrind.git;a=blob;f=callgrind/callgrind.h;hb=HEAD +https://sourceware.org/git/?p=valgrind.git;a=blob;f=include/valgrind.h;hb=HEAD + + diff --git a/third_party/valgrind-headers/callgrind.h b/third_party/valgrind-headers/callgrind.h new file mode 100644 index 000000000000..f078cc82b95d --- /dev/null +++ b/third_party/valgrind-headers/callgrind.h @@ -0,0 +1,129 @@ + +/* + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (callgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of callgrind, a valgrind tool for cache simulation + and call tree tracing. + + Copyright (C) 2003-2017 Josef Weidendorfer. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (callgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + +#ifndef __CALLGRIND_H +#define __CALLGRIND_H + +#include "valgrind.h" + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE ORDER OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end. + + The identification ('C','T') for Callgrind has historical + reasons: it was called "Calltree" before. Besides, ('C','G') would + clash with cachegrind. + */ + +typedef + enum { + VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'), + VG_USERREQ__ZERO_STATS, + VG_USERREQ__TOGGLE_COLLECT, + VG_USERREQ__DUMP_STATS_AT, + VG_USERREQ__START_INSTRUMENTATION, + VG_USERREQ__STOP_INSTRUMENTATION + } Vg_CallgrindClientRequest; + +/* Dump current state of cost centers, and zero them afterwards */ +#define CALLGRIND_DUMP_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS, \ + 0, 0, 0, 0, 0) + +/* Dump current state of cost centers, and zero them afterwards. + The argument is appended to a string stating the reason which triggered + the dump. This string is written as a description field into the + profile data dump. */ +#define CALLGRIND_DUMP_STATS_AT(pos_str) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT, \ + pos_str, 0, 0, 0, 0) + +/* Zero cost centers */ +#define CALLGRIND_ZERO_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS, \ + 0, 0, 0, 0, 0) + +/* Toggles collection state. + The collection state specifies whether the happening of events + should be noted or if they are to be ignored. Events are noted + by increment of counters in a cost center */ +#define CALLGRIND_TOGGLE_COLLECT \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT, \ + 0, 0, 0, 0, 0) + +/* Start full callgrind instrumentation if not already switched on. + When cache simulation is done, it will flush the simulated cache; + this will lead to an artificial cache warmup phase afterwards with + cache misses which would not have happened in reality. */ +#define CALLGRIND_START_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +/* Stop full callgrind instrumentation if not already switched off. + This flushes Valgrinds translation cache, and does no additional + instrumentation afterwards, which effectivly will run at the same + speed as the "none" tool (ie. at minimal slowdown). 
+ Use this to bypass Callgrind aggregation for uninteresting code parts. + To start Callgrind in this mode to ignore the setup phase, use + the option "--instr-atstart=no". */ +#define CALLGRIND_STOP_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +#endif /* __CALLGRIND_H */ diff --git a/third_party/valgrind-headers/valgrind.h b/third_party/valgrind-headers/valgrind.h new file mode 100644 index 000000000000..d33dd30932aa --- /dev/null +++ b/third_party/valgrind-headers/valgrind.h @@ -0,0 +1,7157 @@ +/* -*- c -*- + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (valgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of Valgrind, a dynamic binary instrumentation + framework. + + Copyright (C) 2000-2017 Julian Seward. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (valgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + + +/* This file is for inclusion into client (your!) code. + + You can use these macros to manipulate and query Valgrind's + execution inside your own programs. + + The resulting executables will still run without Valgrind, just a + little bit more slowly than they otherwise would, but otherwise + unchanged. When not running on valgrind, each client request + consumes very few (eg. 
7) instructions, so the resulting performance + loss is negligible unless you plan to execute client requests + millions of times per second. Nevertheless, if that is still a + problem, you can compile with the NVALGRIND symbol defined (gcc + -DNVALGRIND) so that client requests are not even compiled in. */ + +#ifndef __VALGRIND_H +#define __VALGRIND_H + + +/* ------------------------------------------------------------------ */ +/* VERSION NUMBER OF VALGRIND */ +/* ------------------------------------------------------------------ */ + +/* Specify Valgrind's version number, so that user code can + conditionally compile based on our version number. Note that these + were introduced at version 3.6 and so do not exist in version 3.5 + or earlier. The recommended way to use them to check for "version + X.Y or later" is (eg) + +#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__) \ + && (__VALGRIND_MAJOR__ > 3 \ + || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6)) +*/ +#define __VALGRIND_MAJOR__ 3 +#define __VALGRIND_MINOR__ 17 + + +#include + +/* Nb: this file might be included in a file compiled with -ansi. So + we can't use C++ style "//" comments nor the "asm" keyword (instead + use "__asm__"). */ + +/* Derive some tags indicating what the target platform is. Note + that in this file we're using the compiler's CPP symbols for + identifying architectures, which are different to the ones we use + within the rest of Valgrind. Note, __powerpc__ is active for both + 32 and 64-bit PPC, whereas __powerpc64__ is only active for the + latter (on Linux, that is). + + Misc note: how to find out what's predefined in gcc by default: + gcc -Wp,-dM somefile.c +*/ +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_arm64_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + + +#if defined(__APPLE__) && defined(__i386__) +# define PLAT_x86_darwin 1 +#elif defined(__APPLE__) && defined(__x86_64__) +# define PLAT_amd64_darwin 1 +#elif (defined(__MINGW32__) && defined(__i386__)) \ + || defined(__CYGWIN32__) \ + || (defined(_WIN32) && defined(_M_IX86)) +# define PLAT_x86_win32 1 +#elif (defined(__MINGW32__) && defined(__x86_64__)) \ + || (defined(_WIN32) && defined(_M_X64)) +/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. 
*/ +# define PLAT_amd64_win64 1 +#elif defined(__linux__) && defined(__i386__) +# define PLAT_x86_linux 1 +#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__) +# define PLAT_amd64_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__) +# define PLAT_ppc32_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2 +/* Big Endian uses ELF version 1 */ +# define PLAT_ppc64be_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2 +/* Little Endian uses ELF version 2 */ +# define PLAT_ppc64le_linux 1 +#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__) +# define PLAT_arm_linux 1 +#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__) +# define PLAT_arm64_linux 1 +#elif defined(__linux__) && defined(__s390__) && defined(__s390x__) +# define PLAT_s390x_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==64) +# define PLAT_mips64_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==32) +# define PLAT_mips32_linux 1 +#elif defined(__linux__) && defined(__nanomips__) +# define PLAT_nanomips_linux 1 +#elif defined(__sun) && defined(__i386__) +# define PLAT_x86_solaris 1 +#elif defined(__sun) && defined(__x86_64__) +# define PLAT_amd64_solaris 1 +#else +/* If we're not compiling for our target platform, don't generate + any inline asms. */ +# if !defined(NVALGRIND) +# define NVALGRIND 1 +# endif +#endif + + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS. There is nothing */ +/* in here of use to end-users -- skip to the next section. */ +/* ------------------------------------------------------------------ */ + +/* + * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client + * request. Accepts both pointers and integers as arguments. + * + * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind + * client request that does not return a value. + + * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind + * client request and whose value equals the client request result. Accepts + * both pointers and integers as arguments. Note that such calls are not + * necessarily pure functions -- they may have side effects. + */ + +#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default, \ + _zzq_request, _zzq_arg1, _zzq_arg2, \ + _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default), \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1, \ + _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#if defined(NVALGRIND) + +/* Define NVALGRIND to completely remove the Valgrind magic sequence + from the compiled code (analogous to NDEBUG's effects on + assert()) */ +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + (_zzq_default) + +#else /* ! NVALGRIND */ + +/* The following defines the magic code sequences which the JITter + spots and handles magically. Don't look too closely at them as + they will rot your brain. + + The assembly code sequences for all architectures is in this one + file. 
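Concretely, the convenience macros in callgrind.h above are thin wrappers around these request macros. A minimal sketch of issuing the same request by hand, assuming both headers are on the include path and that reset_counters() is a made-up name:

#include "valgrind.h"
#include "callgrind.h"   /* for VG_USERREQ__ZERO_STATS */

static void reset_counters(void)
{
    /* Equivalent to CALLGRIND_ZERO_STATS: ignored when running natively,
       compiled out entirely when NVALGRIND is defined. */
    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS, 0, 0, 0, 0, 0);
}

The _EXPR form behaves the same way except that it yields a value, falling back to its first (default) argument whenever the request is not intercepted.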
This is because this file must be stand-alone, and we don't + want to have multiple files. + + For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default + value gets put in the return slot, so that everything works when + this is executed not under Valgrind. Args are passed in a memory + block, and so there's no intrinsic limit to the number that could + be passed, but it's currently five. + + The macro args are: + _zzq_rlval result lvalue + _zzq_default default value (result returned when running on real CPU) + _zzq_request request code + _zzq_arg1..5 request params + + The other two macros are used to support function wrapping, and are + a lot simpler. VALGRIND_GET_NR_CONTEXT returns the value of the + guest's NRADDR pseudo-register and whatever other information is + needed to safely run the call original from the wrapper: on + ppc64-linux, the R2 value at the divert point is also needed. This + information is abstracted into a user-visible type, OrigFn. + + VALGRIND_CALL_NOREDIR_* behaves the same as the following on the + guest, but guarantees that the branch instruction will not be + redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64: + branch-and-link-to-r11. VALGRIND_CALL_NOREDIR is just text, not a + complete inline asm, since it needs to be combined with more magic + inline asm stuff to be useful. +*/ + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || (defined(PLAT_x86_win32) && defined(__GNUC__)) \ + || defined(PLAT_x86_solaris) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "roll $3, %%edi ; roll $13, %%edi\n\t" \ + "roll $29, %%edi ; roll $19, %%edi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EDX = client_request ( %EAX ) */ \ + "xchgl %%ebx,%%ebx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + "xchgl %%ecx,%%ecx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%EAX */ \ + "xchgl %%edx,%%edx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgl %%edi,%%edi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__) + || PLAT_x86_solaris */ + +/* ------------------------- x86-Win32 ------------------------- */ + +#if defined(PLAT_x86_win32) && !defined(__GNUC__) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#if defined(_MSC_VER) + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm rol edi, 3 __asm rol edi, 13 \ + __asm rol edi, 29 __asm rol edi, 19 + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + valgrind_do_client_request_expr((uintptr_t)(_zzq_default), \ + (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1), \ + (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3), \ + (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5)) + +static __inline uintptr_t +valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request, + uintptr_t _zzq_arg1, uintptr_t _zzq_arg2, + uintptr_t _zzq_arg3, uintptr_t _zzq_arg4, + uintptr_t _zzq_arg5) +{ + volatile uintptr_t _zzq_args[6]; + volatile unsigned int _zzq_result; + _zzq_args[0] = (uintptr_t)(_zzq_request); + _zzq_args[1] = (uintptr_t)(_zzq_arg1); + _zzq_args[2] = (uintptr_t)(_zzq_arg2); + _zzq_args[3] = (uintptr_t)(_zzq_arg3); + _zzq_args[4] = (uintptr_t)(_zzq_arg4); + _zzq_args[5] = (uintptr_t)(_zzq_arg5); + __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default + __SPECIAL_INSTRUCTION_PREAMBLE + /* %EDX = client_request ( %EAX ) */ + __asm xchg ebx,ebx + __asm mov _zzq_result, edx + } + return _zzq_result; +} + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + __asm xchg ecx,ecx \ + __asm mov __addr, eax \ + } \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX ERROR + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm xchg edi,edi \ + } \ + } while (0) + +#else +#error Unsupported compiler. +#endif + +#endif /* PLAT_x86_win32 */ + +/* ----------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) \ + || (defined(PLAT_amd64_win64) && defined(__GNUC__)) + +typedef + struct { + unsigned long int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rolq $3, %%rdi ; rolq $13, %%rdi\n\t" \ + "rolq $61, %%rdi ; rolq $51, %%rdi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RDX = client_request ( %RAX ) */ \ + "xchgq %%rbx,%%rbx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RAX = guest_NRADDR */ \ + "xchgq %%rcx,%%rcx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_RAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%RAX */ \ + "xchgq %%rdx,%%rdx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgq %%rdi,%%rdi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------- amd64-Win64 ------------------------- */ + +#if defined(PLAT_amd64_win64) && !defined(__GNUC__) + +#error Unsupported compiler. + +#endif /* PLAT_amd64_win64 */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rlwinm 0,0,3,0,31 ; rlwinm 0,0,13,0,31\n\t" \ + "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned int _zzq_args[6]; \ + unsigned int _zzq_result; \ + unsigned int* _zzq_ptr; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +#if defined(PLAT_ppc64le_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R12 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "mov r12, r12, ror #3 ; mov r12, r12, ror #13 \n\t" \ + "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("mov r3, %1\n\t" /*default*/ \ + "mov r4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = client_request ( R4 ) */ \ + "orr r10, r10, r10\n\t" \ + "mov %0, r3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "cc","memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = guest_NRADDR */ \ + "orr r11, r11, r11\n\t" \ + "mov %0, r3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R4 */ \ + "orr r12, r12, r12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr r9, r9, r9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------- */ + +#if defined(PLAT_arm64_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "ror x12, x12, #3 ; ror x12, x12, #13 \n\t" \ + "ror x12, x12, #51 ; ror x12, x12, #61 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("mov x3, %1\n\t" /*default*/ \ + "mov x4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = client_request ( X4 ) */ \ + "orr x10, x10, x10\n\t" \ + "mov %0, x3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" ((unsigned long int)(_zzq_default)), \ + "r" (&_zzq_args[0]) \ + : "cc","memory", "x3", "x4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = guest_NRADDR */ \ + "orr x11, x11, x11\n\t" \ + "mov %0, x3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "x3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir X8 */ \ + "orr x12, x12, x12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr x9, x9, x9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------ s390x-linux ------------------------ */ + +#if defined(PLAT_s390x_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific + * code. This detection is implemented in platform specific toIR.c + * (e.g. VEX/priv/guest_s390_decoder.c). 
+ */ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "lr 15,15\n\t" \ + "lr 1,1\n\t" \ + "lr 2,2\n\t" \ + "lr 3,3\n\t" + +#define __CLIENT_REQUEST_CODE "lr 2,2\n\t" +#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t" +#define __CALL_NO_REDIR_CODE "lr 4,4\n\t" +#define __VEX_INJECT_IR_CODE "lr 5,5\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(/* r2 = args */ \ + "lgr 2,%1\n\t" \ + /* r3 = default */ \ + "lgr 3,%2\n\t" \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CLIENT_REQUEST_CODE \ + /* results = r3 */ \ + "lgr %0, 3\n\t" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "2", "3", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __GET_NR_CONTEXT_CODE \ + "lgr %0, 3\n\t" \ + : "=a" (__addr) \ + : \ + : "cc", "3", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_R1 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CALL_NO_REDIR_CODE + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __VEX_INJECT_IR_CODE); \ + } while (0) + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ---------------- */ + +#if defined(PLAT_mips32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +/* .word 0x342 + * .word 0x742 + * .word 0xC2 + * .word 0x4C2*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "srl $0, $0, 13\n\t" \ + "srl $0, $0, 29\n\t" \ + "srl $0, $0, 3\n\t" \ + "srl $0, $0, 19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* T3 = client_request ( T4 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %t9 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%t9 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- mips64-linux ---------------- */ + +#if defined(PLAT_mips64_linux) + +typedef + struct { + unsigned long nraddr; /* where's the code? 
*/ + } + OrigFn; + +/* dsll $0,$0, 3 + * dsll $0,$0, 13 + * dsll $0,$0, 29 + * dsll $0,$0, 19*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "dsll $0,$0, 3 ; dsll $0,$0,13\n\t" \ + "dsll $0,$0,29 ; dsll $0,$0,19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = client_request ( $12 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +#if defined(PLAT_nanomips_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; +/* + 8000 c04d srl zero, zero, 13 + 8000 c05d srl zero, zero, 29 + 8000 c043 srl zero, zero, 3 + 8000 c053 srl zero, zero, 19 +*/ + +#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \ + "srl[32] $zero, $zero, 29 \n\t" \ + "srl[32] $zero, $zero, 3 \n\t" \ + "srl[32] $zero, $zero, 19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $a7, %1\n\t" /* default */ \ + "move $t0, %2\n\t" /* ptr */ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = client_request( $t0 ) */ \ + "or[32] $t0, $t0, $t0\n\t" \ + "move %0, $a7\n\t" /* result */ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$a7", "$t0", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = guest_NRADDR */ \ + "or[32] $t1, $t1, $t1\n\t" \ + "move %0, $a7" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$a7"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or[32] $t2, $t2, $t2\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or[32] $t3, $t3, $t3\n\t" \ + ); \ + } while (0) + +#endif +/* Insert assembly code for other platforms here... */ + +#endif /* NVALGRIND */ + + +/* ------------------------------------------------------------------ */ +/* PLATFORM SPECIFICS for FUNCTION WRAPPING. This is all very */ +/* ugly. It's the least-worst tradeoff I can think of. */ +/* ------------------------------------------------------------------ */ + +/* This section defines magic (a.k.a appalling-hack) macros for doing + guaranteed-no-redirection macros, so as to get from function + wrappers to the functions they are wrapping. The whole point is to + construct standard call sequences, but to do the call itself with a + special no-redirect call pseudo-instruction that the JIT + understands and handles specially. This section is long and + repetitious, and I can't see a way to make it shorter. + + The naming scheme is as follows: + + CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc} + + 'W' stands for "word" and 'v' for "void". Hence there are + different macros for calling arity 0, 1, 2, 3, 4, etc, functions, + and for each, the possibility of returning a word-typed result, or + no result. +*/ + +/* Use these to write the name of your wrapper. NOTE: duplicates + VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h. NOTE also: inserts + the default behaviour equivalance class tag "0000" into the name. + See pub_tool_redir.h for details -- normally you don't need to + think about this, though. */ + +/* Use an extra level of macroisation so as to ensure the soname/fnname + args are fully macro-expanded before pasting them together. 
*/ +#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd + +#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgw00000ZU_,soname,_,fnname) + +#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname) + +/* Use this macro from within a wrapper function to collect the + context (address and possibly other info) of the original function. + Once you have that you can then use it in one of the CALL_FN_ + macros. The type of the argument _lval is OrigFn. */ +#define VALGRIND_GET_ORIG_FN(_lval) VALGRIND_GET_NR_CONTEXT(_lval) + +/* Also provide end-user facilities for function replacement, rather + than wrapping. A replacement function differs from a wrapper in + that it has no way to get hold of the original function being + called, and hence no way to call onwards to it. In a replacement + function, VALGRIND_GET_ORIG_FN always returns zero. */ + +#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgr00000ZU_,soname,_,fnname) + +#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname) + +/* Derivatives of the main macros below, for calling functions + returning void. */ + +#define CALL_FN_v_v(fnptr) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_v(_junk,fnptr); } while (0) + +#define CALL_FN_v_W(fnptr, arg1) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_W(_junk,fnptr,arg1); } while (0) + +#define CALL_FN_v_WW(fnptr, arg1,arg2) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) + +#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) + +#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) + +#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) + +#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) + +#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || defined(PLAT_x86_solaris) + +/* These regs are trashed by the hidden call. No need to mention eax + as gcc can already see that, plus causes gcc to bomb. */ +#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movl %%esp,%%edi\n\t" \ + "andl $0xfffffff0,%%esp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movl %%edi,%%esp\n\t" + +/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned + long) == 4. 
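Putting the wrapping pieces together, a minimal sketch of a wrapper for a hypothetical int foo(int, int), where the soname tag NONE is assumed to match an object with an empty soname such as the main executable:

#include <stdio.h>
#include "valgrind.h"

/* Wrapper for a hypothetical  int foo(int x, int y)  in the main executable. */
int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(int x, int y)
{
    OrigFn fn;
    int    result;
    VALGRIND_GET_ORIG_FN(fn);        /* capture the original function          */
    printf("foo wrapper: %d %d\n", x, y);
    CALL_FN_W_WW(result, fn, x, y);  /* call the original without redirection  */
    return result;
}

The wrapper keeps foo's exact signature; CALL_FN_W_WW forwards the two word-sized arguments and stores the word-sized result, following the CALL_FN_{W,v}_... naming scheme described above.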
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned 
long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + 
"pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ 
volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 48(%%eax)\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */ + +/* ---------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) + +/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi", \ + "rdi", "r8", "r9", "r10", "r11" + +/* This is all pretty complex. It's so as to make stack unwinding + work reliably. See bug 243270. The basic problem is the sub and + add of 128 of %rsp in all of the following macros. If gcc believes + the CFA is in %rsp, then unwinding may fail, because what's at the + CFA is not what gcc "expected" when it constructs the CFIs for the + places where the macros are instantiated. + + But we can't just add a CFI annotation to increase the CFA offset + by 128, to match the sub of 128 from %rsp, because we don't know + whether gcc has chosen %rsp as the CFA at that point, or whether it + has chosen some other register (eg, %rbp). In the latter case, + adding a CFI annotation to change the CFA offset is simply wrong. + + So the solution is to get hold of the CFA using + __builtin_dwarf_cfa(), put it in a known register, and add a + CFI annotation to say what the register is. We choose %rbp for + this (perhaps perversely), because: + + (1) %rbp is already subject to unwinding. 
If a new register was + chosen then the unwinder would have to unwind it in all stack + traces, which is expensive, and + + (2) %rbp is already subject to precise exception updates in the + JIT. If a new register was chosen, we'd have to have precise + exceptions for it too, which reduces performance of the + generated code. + + However .. one extra complication. We can't just whack the result + of __builtin_dwarf_cfa() into %rbp and then add %rbp to the + list of trashed registers at the end of the inline assembly + fragments; gcc won't allow %rbp to appear in that list. Hence + instead we need to stash %rbp in %r15 for the duration of the asm, + and say that %r15 is trashed instead. gcc seems happy to go with + that. + + Oh .. and this all needs to be conditionalised so that it is + unchanged from before this commit, when compiled with older gccs + that don't support __builtin_dwarf_cfa. Furthermore, since + this header file is freestanding, it has to be independent of + config.h, and so the following conditionalisation cannot depend on + configure time checks. + + Although it's not clear from + 'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)', + this expression excludes Darwin. + .cfi directives in Darwin assembly appear to be completely + different and I haven't investigated how they work. + + For even more entertainment value, note we have to use the + completely undocumented __builtin_dwarf_cfa(), which appears to + really compute the CFA, whereas __builtin_frame_address(0) claims + to but actually doesn't. See + https://bugs.kde.org/show_bug.cgi?id=243270#c47 +*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"r"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + "movq %%rbp, %%r15\n\t" \ + "movq %2, %%rbp\n\t" \ + ".cfi_remember_state\n\t" \ + ".cfi_def_cfa rbp, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "movq %%r15, %%rbp\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movq %%rsp,%%r14\n\t" \ + "andq $0xfffffffffffffff0,%%rsp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movq %%r14,%%rsp\n\t" + +/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned + long) == 8. */ + +/* NB 9 Sept 07. There is a nasty kludge here in all these CALL_FN_ + macros. In order not to trash the stack redzone, we need to drop + %rsp by 128 before the hidden call, and restore afterwards. The + nastyness is that it is only by luck that the stack still appears + to be unwindable during the hidden call - since then the behaviour + of any routine using this macro does not match what the CFI data + says. Sigh. + + Why is this important? Imagine that a wrapper has a stack + allocated local, and passes to the hidden call, a pointer to it. + Because gcc does not know about the hidden call, it may allocate + that local in the redzone. Unfortunately the hidden call may then + trash it before it comes to use it. So we must step clear of the + redzone, for the duration of the hidden call, to make it safe. 
+ + Probably the same problem afflicts the other redzone-style ABIs too + (ppc64-linux); but for those, the stack is + self describing (none of this CFI nonsense) so at least messing + with the stack pointer doesn't give a danger of non-unwindable + stack. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = 
(unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ 
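+         /* args 1..6 now sit in rdi,rsi,rdx,rcx,r8,r9; arg7 was pushed above */ \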
+ "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ 
+ VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 96(%%rax)\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) 
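+                      /* __FRAME_POINTER, if non-empty, appends the CFA as input %2 for VALGRIND_CFI_PROLOGUE */ \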
__FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +/* This is useful for finding out about the on-stack stuff: + + extern int f9 ( int,int,int,int,int,int,int,int,int ); + extern int f10 ( int,int,int,int,int,int,int,int,int,int ); + extern int f11 ( int,int,int,int,int,int,int,int,int,int,int ); + extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int ); + + int g9 ( void ) { + return f9(11,22,33,44,55,66,77,88,99); + } + int g10 ( void ) { + return f10(11,22,33,44,55,66,77,88,99,110); + } + int g11 ( void ) { + return f11(11,22,33,44,55,66,77,88,99,110,121); + } + int g12 ( void ) { + return f12(11,22,33,44,55,66,77,88,99,110,121,132); + } +*/ + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rlwinm 1,1,0,0,27\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc32-linux, + sizeof(unsigned long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, 
"r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", 
__CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = 
(__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; 
\ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg12 */ \ + "lwz 3,48(11)\n\t" \ + "stw 3,20(1)\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. 
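+
+   Note on the _argvec layout used below: the asm is handed
+   &_argvec[2], so the target address sits at offset 0, the arguments
+   at offsets 8, 16, ..., the callee's TOC pointer (_orig.r2) at
+   offset -8, and _argvec[0] at offset -16 is used to park the
+   caller's r2 across the call ("save tocptr" and "restore tocptr"
+   in the asm below).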
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 
2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 
48(11)\n\t" /* arg6->r8 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ 
\ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + 
_argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(11)\n\t" \ + "std 3,136(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +/* 
------------------------- ppc64le-linux ----------------------- */ +#if defined(PLAT_ppc64le_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } 
while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + 
volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 
24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 
48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(12)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ 
+ "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +/* This is a bit tricky. We store the original stack pointer in r10 + as it is callee-saves. gcc doesn't allow the use of r11 for some + reason. Also, we can't directly "bic" the stack pointer in thumb + mode since r13 isn't an allowed register number in that context. + So use r4 as a temporary, since that is about to get trashed + anyway, just after each use of this macro. Side effect is we need + to be very careful about any future changes, since + VALGRIND_ALIGN_STACK simply assumes r4 is usable. */ +#define VALGRIND_ALIGN_STACK \ + "mov r10, sp\n\t" \ + "mov r4, sp\n\t" \ + "bic r4, r4, #7\n\t" \ + "mov sp, r4\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, r10\n\t" + +/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned + long) == 4. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, 
arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned 
long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "push {r0, r1, r2, r3} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile 
unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "ldr r2, [%1, #48] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------ */ + +#if defined(PLAT_arm64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9", \ + "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", \ + "x18", "x19", "x20", "x30", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", \ + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", \ + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", \ + "v26", "v27", "v28", "v29", "v30", "v31" + +/* x21 is callee-saved, so we can use it to save and restore SP around + the hidden call. 
*/ +#define VALGRIND_ALIGN_STACK \ + "mov x21, sp\n\t" \ + "bic sp, x21, #15\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, x21\n\t" + +/* These CALL_FN_ macros assume that on arm64-linux, + sizeof(unsigned long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : 
/*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned 
long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = 
(orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11, \ + arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1, #96] \n\t" \ + "str x8, [sp, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------- s390x-linux ------------------------- */ + +#if defined(PLAT_s390x_linux) + +/* Similar workaround as amd64 (see above), but we use r11 as frame + pointer and save the old r11 in r7. r11 might be used for + argvec, therefore we copy argvec in r1 since r1 is clobbered + after the call anyway. 
*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"d"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + ".cfi_remember_state\n\t" \ + "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */ \ + "lgr 7,11\n\t" \ + "lgr 11,%2\n\t" \ + ".cfi_def_cfa r11, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "lgr 11, 7\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE \ + "lgr 1,%1\n\t" +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Nb: On s390 the stack pointer is properly aligned *at all times* + according to the s390 GCC maintainer. (The ABI specification is not + precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and + VALGRIND_RESTORE_STACK are not defined here. */ + +/* These regs are trashed by the hidden call. Note that we overwrite + r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the + function a proper return address. All others are ABI defined call + clobbers. */ +#if defined(__VX__) || defined(__S390_VX__) +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", \ + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \ + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", \ + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" +#else +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7" +#endif + +/* Nb: Although r11 is modified in the asm snippets below (inside + VALGRIND_CFI_PROLOGUE) it is not listed in the clobber section, for + two reasons: + (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is not + modified + (2) GCC will complain that r11 cannot appear inside a clobber section, + when compiled with -O -fno-omit-frame-pointer + */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 1, 0(1)\n\t" /* target->r1 */ \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "d" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +/* The call abi has the arguments in r2-r6 and stack */ +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1, arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + 
VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-168\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,168\n\t" \ + 
VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-176\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,176\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-184\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,184\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-192\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,192\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", 
__CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-200\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,200\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-208\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,208\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-216\n\t" 
\ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "mvc 208(8,15), 96(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,216\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ----------------------- */ + +#if defined(PLAT_mips32_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16\n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" /* arg1*/ \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned 
long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 24\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 24 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "nop\n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) 
\n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned 
long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ 
+ "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 56\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 48(%1) \n\t" \ + "sw $4, 44($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 56 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- nanomips-linux -------------------- */ + +#if defined(PLAT_nanomips_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2", \ +"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3", \ +"$t8","$t9", "$at" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + 
"lw $a4,20(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + "lw $a7,32(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); 
\ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long 
_argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9,48(%1) \n\t" \ + "sw $t9,12($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_nanomips_linux */ + +/* ------------------------- mips64-linux ------------------------- */ + +#if defined(PLAT_mips64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips64-linux, + sizeof(long long) == 8. */ + +#define MIPS64_LONG2REG_CAST(x) ((long long)(long)x) + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[1]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + __asm__ volatile( \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[2]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" /* arg1*/ \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[3]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + 
volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[4]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[5]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[6]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[7]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[8]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = 
MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[9]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[10]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + __asm__ volatile( \ + "dsubu $29, $29, 8\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 8\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[11]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + 
_argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + __asm__ volatile( \ + "dsubu $29, $29, 16\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 16\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[12]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + __asm__ volatile( \ + "dsubu $29, $29, 24\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 24\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[13]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + _argvec[12] = MIPS64_LONG2REG_CAST(arg12); \ + __asm__ volatile( \ + "dsubu $29, $29, 32\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 96(%1)\n\t" \ + "sd $4, 24($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 
56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 32\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS. */ +/* */ +/* ------------------------------------------------------------------ */ + +/* Some request codes. There are many more of these, but most are not + exposed to end-user view. These are the public ones, all of the + form 0x1000 + small_number. + + Core ones are in the range 0x00000000--0x0000ffff. The non-public + ones start at 0x2000. +*/ + +/* These macros are used by tools -- they must be public, but don't + embed them into other programs. */ +#define VG_USERREQ_TOOL_BASE(a,b) \ + ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16)) +#define VG_IS_TOOL_USERREQ(a, b, v) \ + (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000)) + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE NUMERIC VALUES OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end of the most + relevant group. */ +typedef + enum { VG_USERREQ__RUNNING_ON_VALGRIND = 0x1001, + VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002, + + /* These allow any function to be called from the simulated + CPU but run on the real CPU. Nb: the first arg passed to + the function is always the ThreadId of the running + thread! So CLIENT_CALL0 actually requires a 1 arg + function, etc. */ + VG_USERREQ__CLIENT_CALL0 = 0x1101, + VG_USERREQ__CLIENT_CALL1 = 0x1102, + VG_USERREQ__CLIENT_CALL2 = 0x1103, + VG_USERREQ__CLIENT_CALL3 = 0x1104, + + /* Can be useful in regression testing suites -- eg. can + send Valgrind's output to /dev/null and still count + errors. */ + VG_USERREQ__COUNT_ERRORS = 0x1201, + + /* Allows the client program and/or gdbserver to execute a monitor + command. */ + VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202, + + /* Allows the client program to change a dynamic command line + option. */ + VG_USERREQ__CLO_CHANGE = 0x1203, + + /* These are useful and can be interpreted by any tool that + tracks malloc() et al, by using vg_replace_malloc.c. */ + VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301, + VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b, + VG_USERREQ__FREELIKE_BLOCK = 0x1302, + /* Memory pool support. */ + VG_USERREQ__CREATE_MEMPOOL = 0x1303, + VG_USERREQ__DESTROY_MEMPOOL = 0x1304, + VG_USERREQ__MEMPOOL_ALLOC = 0x1305, + VG_USERREQ__MEMPOOL_FREE = 0x1306, + VG_USERREQ__MEMPOOL_TRIM = 0x1307, + VG_USERREQ__MOVE_MEMPOOL = 0x1308, + VG_USERREQ__MEMPOOL_CHANGE = 0x1309, + VG_USERREQ__MEMPOOL_EXISTS = 0x130a, + + /* Allow printfs to valgrind log. */ + /* The first two pass the va_list argument by value, which + assumes it is the same size as or smaller than a UWord, + which generally isn't the case. Hence are deprecated. + The second two pass the vargs by reference and so are + immune to this problem. */ + /* both :: char* fmt, va_list vargs (DEPRECATED) */ + VG_USERREQ__PRINTF = 0x1401, + VG_USERREQ__PRINTF_BACKTRACE = 0x1402, + /* both :: char* fmt, va_list* vargs */ + VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404, + + /* Stack support. 
*/ + VG_USERREQ__STACK_REGISTER = 0x1501, + VG_USERREQ__STACK_DEREGISTER = 0x1502, + VG_USERREQ__STACK_CHANGE = 0x1503, + + /* Wine support */ + VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601, + + /* Querying of debug info. */ + VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701, + + /* Disable/enable error reporting level. Takes a single + Word arg which is the delta to this thread's error + disablement indicator. Hence 1 disables or further + disables errors, and -1 moves back towards enablement. + Other values are not allowed. */ + VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801, + + /* Some requests used for Valgrind internal, such as + self-test or self-hosting. */ + /* Initialise IR injection */ + VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901, + /* Used by Inner Valgrind to inform Outer Valgrind where to + find the list of inner guest threads */ + VG_USERREQ__INNER_THREADS = 0x1902 + } Vg_ClientRequest; + +#if !defined(__GNUC__) +# define __extension__ /* */ +#endif + + +/* Returns the number of Valgrinds this code is running under. That + is, 0 if running natively, 1 if running under Valgrind, 2 if + running under Valgrind which is running under another Valgrind, + etc. */ +#define RUNNING_ON_VALGRIND \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */, \ + VG_USERREQ__RUNNING_ON_VALGRIND, \ + 0, 0, 0, 0, 0) \ + + +/* Discard translation of code in the range [_qzz_addr .. _qzz_addr + + _qzz_len - 1]. Useful if you are debugging a JITter or some such, + since it provides a way to make sure valgrind will retranslate the + invalidated area. Returns no value. */ +#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS, \ + _qzz_addr, _qzz_len, 0, 0, 0) + +#define VALGRIND_INNER_THREADS(_qzz_addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS, \ + _qzz_addr, 0, 0, 0, 0) + + +/* These requests are for getting Valgrind itself to print something. + Possibly with a backtrace. This is a really ugly hack. The return value + is the number of characters printed, excluding the "**** " part at the + start and the backtrace (if present). */ + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +/* Modern GCC will optimize the static routine out if unused, + and unused attribute will shut down warnings about it. */ +static int VALGRIND_PRINTF(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF_BACKTRACE(const char *format, ...) 
+{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + + +/* These requests allow control to move from the simulated CPU to the + real CPU, calling an arbitrary function. + + Note that the current ThreadId is inserted as the first argument. + So this call: + + VALGRIND_NON_SIMD_CALL2(f, arg1, arg2) + + requires f to have this signature: + + Word f(Word tid, Word arg1, Word arg2) + + where "Word" is a word-sized type. + + Note that these client requests are not entirely reliable. For example, + if you call a function with them that subsequently calls printf(), + there's a high chance Valgrind will crash. Generally, your prospects of + these working are made higher if the called function does not refer to + any global variables, and does not refer to any libc or other functions + (printf et al). Any kind of entanglement with libc or dynamic linking is + likely to have a bad outcome, for tricky reasons which we've grappled + with a lot in the past. +*/ +#define VALGRIND_NON_SIMD_CALL0(_qyy_fn) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL0, \ + _qyy_fn, \ + 0, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL1, \ + _qyy_fn, \ + _qyy_arg1, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL2, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, 0, 0) + +#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL3, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, \ + _qyy_arg3, 0) + + +/* Counts the number of errors that have been recorded by a tool. Nb: + the tool must record the errors with VG_(maybe_record_error)() or + VG_(unique_error)() for them to be counted. */ +#define VALGRIND_COUNT_ERRORS \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + 0 /* default return */, \ + VG_USERREQ__COUNT_ERRORS, \ + 0, 0, 0, 0, 0) + +/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing + when heap blocks are allocated in order to give accurate results. This + happens automatically for the standard allocator functions such as + malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete, + delete[], etc. + + But if your program uses a custom allocator, this doesn't automatically + happen, and Valgrind will not do as well. For example, if you allocate + superblocks with mmap() and then allocates chunks of the superblocks, all + Valgrind's observations will be at the mmap() level and it won't know that + the chunks should be considered separate entities. In Memcheck's case, + that means you probably won't get heap block overrun detection (because + there won't be redzones marked as unaddressable) and you definitely won't + get any leak detection. 
+ + The following client requests allow a custom allocator to be annotated so + that it can be handled accurately by Valgrind. + + VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated + by a malloc()-like function. For Memcheck (an illustrative case), this + does two things: + + - It records that the block has been allocated. This means any addresses + within the block mentioned in error messages will be + identified as belonging to the block. It also means that if the block + isn't freed it will be detected by the leak checker. + + - It marks the block as being addressable and undefined (if 'is_zeroed' is + not set), or addressable and defined (if 'is_zeroed' is set). This + controls how accesses to the block by the program are handled. + + 'addr' is the start of the usable block (ie. after any + redzone), 'sizeB' is its size. 'rzB' is the redzone size if the allocator + can apply redzones -- these are blocks of padding at the start and end of + each block. Adding redzones is recommended as it makes it much more likely + Valgrind will spot block overruns. `is_zeroed' indicates if the memory is + zeroed (or filled with another predictable value), as is the case for + calloc(). + + VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a + heap block -- that will be used by the client program -- is allocated. + It's best to put it at the outermost level of the allocator if possible; + for example, if you have a function my_alloc() which calls + internal_alloc(), and the client request is put inside internal_alloc(), + stack traces relating to the heap block will contain entries for both + my_alloc() and internal_alloc(), which is probably not what you want. + + For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out + custom blocks from within a heap block, B, that has been allocated with + malloc/calloc/new/etc, then block B will be *ignored* during leak-checking + -- the custom blocks will take precedence. + + VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK. For + Memcheck, it does two things: + + - It records that the block has been deallocated. This assumes that the + block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - It marks the block as being unaddressable. + + VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a + heap block is deallocated. + + VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For + Memcheck, it does four things: + + - It records that the size of a block has been changed. This assumes that + the block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - If the block shrunk, it marks the freed memory as being unaddressable. + + - If the block grew, it marks the new area as undefined and defines a red + zone past the end of the new block. + + - The V-bits of the overlap between the old and the new block are preserved. + + VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block + and before deallocation of the old block. + + In many cases, these three client requests will not be enough to get your + allocator working well with Memcheck. More specifically, if your allocator + writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call + will be necessary to mark the memory as addressable just before the zeroing + occurs, otherwise you'll get a lot of invalid write errors. 
For example, + you'll need to do this if your allocator recycles freed blocks, but it + zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK). + Alternatively, if your allocator reuses freed blocks for allocator-internal + data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary. + + Really, what's happening is a blurring of the lines between the client + program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the + memory should be considered unaddressable to the client program, but the + allocator knows more than the rest of the client program and so may be able + to safely access it. Extra client requests are necessary for Valgrind to + understand the distinction between the allocator and the rest of the + program. + + Ignored if addr == 0. +*/ +#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK, \ + addr, sizeB, rzB, is_zeroed, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK, \ + addr, oldSizeB, newSizeB, rzB, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_FREELIKE_BLOCK(addr, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK, \ + addr, rzB, 0, 0, 0) + +/* Create a memory pool. */ +#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, 0, 0) + +/* Create a memory pool with some flags specifying extended behaviour. + When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL. + + The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory + associated with the pool using VALGRIND_MEMPOOL_ALLOC will be used + by the application as superblocks to dole out MALLOC_LIKE blocks using + VALGRIND_MALLOCLIKE_BLOCK. In other words, a meta pool is a "2 levels" + pool : first level is the blocks described by VALGRIND_MEMPOOL_ALLOC. + The second level blocks are described using VALGRIND_MALLOCLIKE_BLOCK. + Note that the association between the pool and the second level blocks + is implicit : second level blocks will be located inside first level + blocks. It is necessary to use the VALGRIND_MEMPOOL_METAPOOL flag + for such 2 levels pools, as otherwise valgrind will detect overlapping + memory blocks, and will abort execution (e.g. during leak search). + + Such a meta pool can also be marked as an 'auto free' pool using the flag + VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with the + VALGRIND_MEMPOOL_METAPOOL. For an 'auto free' pool, VALGRIND_MEMPOOL_FREE + will automatically free the second level blocks that are contained + inside the first level block freed with VALGRIND_MEMPOOL_FREE. + In other words, calling VALGRIND_MEMPOOL_FREE will cause implicit calls + to VALGRIND_FREELIKE_BLOCK for all the second level blocks included + in the first level block. + Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag + without the VALGRIND_MEMPOOL_METAPOOL flag. +*/ +#define VALGRIND_MEMPOOL_AUTO_FREE 1 +#define VALGRIND_MEMPOOL_METAPOOL 2 +#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, flags, 0) + +/* Destroy a memory pool. 
*/ +#define VALGRIND_DESTROY_MEMPOOL(pool) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL, \ + pool, 0, 0, 0, 0) + +/* Associate a piece of memory with a memory pool. */ +#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC, \ + pool, addr, size, 0, 0) + +/* Disassociate a piece of memory from a memory pool. */ +#define VALGRIND_MEMPOOL_FREE(pool, addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE, \ + pool, addr, 0, 0, 0) + +/* Disassociate any pieces outside a particular range. */ +#define VALGRIND_MEMPOOL_TRIM(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM, \ + pool, addr, size, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MOVE_MEMPOOL(poolA, poolB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL, \ + poolA, poolB, 0, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE, \ + pool, addrA, addrB, size, 0) + +/* Return 1 if a mempool exists, else 0. */ +#define VALGRIND_MEMPOOL_EXISTS(pool) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MEMPOOL_EXISTS, \ + pool, 0, 0, 0, 0) + +/* Mark a piece of memory as being a stack. Returns a stack id. + start is the lowest addressable stack byte, end is the highest + addressable stack byte. */ +#define VALGRIND_STACK_REGISTER(start, end) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__STACK_REGISTER, \ + start, end, 0, 0, 0) + +/* Unmark the piece of memory associated with a stack id as being a + stack. */ +#define VALGRIND_STACK_DEREGISTER(id) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \ + id, 0, 0, 0, 0) + +/* Change the start and end address of the stack id. + start is the new lowest addressable stack byte, end is the new highest + addressable stack byte. */ +#define VALGRIND_STACK_CHANGE(id, start, end) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE, \ + id, start, end, 0, 0) + +/* Load PDB debug info for Wine PE image_map. */ +#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \ + fd, ptr, total_size, delta, 0) + +/* Map a code address to a source file name and line number. buf64 + must point to a 64-byte buffer in the caller's address space. The + result will be dumped in there and is guaranteed to be zero + terminated. If no info is found, the first byte is set to zero. */ +#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MAP_IP_TO_SRCLOC, \ + addr, buf64, 0, 0, 0) + +/* Disable error reporting for this thread. Behaves in a stack like + way, so you can safely call this multiple times provided that + VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times + to re-enable reporting. The first call of this macro disables + reporting. Subsequent calls have no effect except to increase the + number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable + reporting. Child threads do not inherit this setting from their + parents -- they are always created with reporting enabled. */ +#define VALGRIND_DISABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + 1, 0, 0, 0, 0) + +/* Re-enable error reporting, as per comments on + VALGRIND_DISABLE_ERROR_REPORTING. 
*/
+#define VALGRIND_ENABLE_ERROR_REPORTING                          \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \
+                                    -1, 0, 0, 0, 0)
+
+/* Execute a monitor command from the client program.
+   If a connection is opened with GDB, the output will be sent
+   according to the output mode set for vgdb.
+   If no connection is opened, output will go to the log output.
+   Returns 1 if command not recognised, 0 otherwise. */
+#define VALGRIND_MONITOR_COMMAND(command)                               \
+   VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \
+                                   command, 0, 0, 0, 0)
+
+
+/* Change the value of a dynamic command line option.
+   Note that unknown or not dynamically changeable options
+   will cause a warning message to be output. */
+#define VALGRIND_CLO_CHANGE(option)                           \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE,   \
+                                    option, 0, 0, 0, 0)
+
+
+#undef PLAT_x86_darwin
+#undef PLAT_amd64_darwin
+#undef PLAT_x86_win32
+#undef PLAT_amd64_win64
+#undef PLAT_x86_linux
+#undef PLAT_amd64_linux
+#undef PLAT_ppc32_linux
+#undef PLAT_ppc64be_linux
+#undef PLAT_ppc64le_linux
+#undef PLAT_arm_linux
+#undef PLAT_s390x_linux
+#undef PLAT_mips32_linux
+#undef PLAT_mips64_linux
+#undef PLAT_nanomips_linux
+#undef PLAT_x86_solaris
+#undef PLAT_amd64_solaris
+
+#endif   /* __VALGRIND_H */
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index f64025c34683..ba66d504c4b7 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -66,8 +66,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
     ${CMAKE_BINARY_DIR}/third_party
     ${CMAKE_BINARY_DIR}/third_party/onnx
-    ${TORCH_ROOT}/third_party/valgrind/callgrind
-    ${TORCH_ROOT}/third_party/valgrind/include
+    ${TORCH_ROOT}/third_party/valgrind-headers
     ${TORCH_ROOT}/third_party/gloo
     ${TORCH_ROOT}/third_party/onnx

From 4fdba305003006a0666efb70ce75cebead572f1e Mon Sep 17 00:00:00 2001
From: Meghan Lele
Date: Tue, 6 Oct 2020 18:00:27 -0700
Subject: [PATCH 16/69] [JIT] Add API for ignoring arbitrary module attributes (#45262)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45262

**Summary**
This commit adds an API for ignoring arbitrary module attributes during
scripting. A class attribute named `__jit_ignored_attributes__` containing the
names of attributes to ignore can be added to the class of the instance being
scripted. Attributes ignored in this fashion cannot be used in `forward`, in
methods called by `forward`, or in `exported` methods. They are, however,
copied to the `RecursiveScriptModule` wrapper and can be used by `ignored`
methods and regular Python code.

**Test Plan**
This commit adds unit tests to `TestScriptPy3` to test this new API.
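
For illustration, a minimal usage sketch of the new API (the module, attribute,
and method names below are invented for this example and are not part of the
patch; the pattern mirrors the unit tests added below):

```python
import torch

class CacheModule(torch.nn.Module):
    # Attributes listed here are skipped by the scripting frontend.
    __jit_ignored_attributes__ = ["debug_cache"]

    def __init__(self):
        super().__init__()
        self.debug_cache = {}  # not scriptable, but ignored, so scripting still succeeds
        self.scale = 2.0

    def forward(self, x):
        # Ignored attributes must not be referenced in forward or exported methods.
        return x * self.scale

    @torch.jit.ignore
    def record(self, key: str, value: int) -> None:
        # Ignored methods run in eager Python and may still use the ignored
        # attribute, which is copied onto the RecursiveScriptModule wrapper.
        self.debug_cache[key] = value

scripted = torch.jit.script(CacheModule())
scripted.record("calls", 1)
print(scripted(torch.ones(2)))  # tensor([2., 2.])
```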
Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D23971882 Pulled By: SplitInfinity fbshipit-source-id: 8c81fb415fde7b78aa2f87e5d83a477e876a7cc3 --- test/jit/test_type_sharing.py | 51 +++++++++++++++++++ test/test_jit_py3.py | 50 ++++++++++++++++++ torch/_C/__init__.pyi.in | 2 + .../jit/frontend/concrete_module_type.cpp | 9 ++++ .../csrc/jit/frontend/concrete_module_type.h | 5 ++ .../csrc/jit/python/python_sugared_value.cpp | 2 + torch/csrc/jit/python/script_init.cpp | 12 +++++ torch/jit/_recursive.py | 20 +++++++- 8 files changed, 150 insertions(+), 1 deletion(-) diff --git a/test/jit/test_type_sharing.py b/test/jit/test_type_sharing.py index 7981ed96d510..cb6677937b96 100644 --- a/test/jit/test_type_sharing.py +++ b/test/jit/test_type_sharing.py @@ -560,3 +560,54 @@ def forward(self, x): self.assertDifferentType(top1_s, top2_s.sub) self.assertDifferentType(top2_s, top2_s.sub) self.assertDifferentType(top2_s, top1_s.sub) + + def test_type_shared_ignored_attributes(self): + """ + Test that types are shared if the exclusion of their + ignored attributes makes them equal. + """ + class A(torch.nn.Module): + __jit_ignored_attributes__ = ["a"] + + def __init__(self, a, b): + super().__init__() + self.a = a + self.b = b + + def forward(self, x): + return x + + a_with_linear = A(torch.nn.Linear(5, 5), 5) + a_with_string = A("string", 10) + + # Both should have the same type because the attribute + # that differs in type is ignored and the common attribute + # has the same type. + self.assertSameType(a_with_linear, a_with_string) + + def test_type_not_shared_ignored_attributes(self): + """ + Test that types are not shared if the exclusion of their + ignored attributes makes them not equal. + """ + class A(torch.nn.Module): + __jit_ignored_attributes__ = ["a"] + + def __init__(self, a, b, c): + super().__init__() + self.a = a + self.b = b + self.c = c + + def forward(self, x): + return x + + mod = A(torch.nn.Linear(5, 5), 5, "string") + s1 = torch.jit.script(mod) + A.__jit_ignored_attributes__ = ["a", "b"] + s2 = torch.jit.script(mod) + + # The types of s1 and s2 should differ. Although they are instances + # of A, __jit_ignored_attributes__ was modified before scripting s2, + # so the set of ignored attributes is different between s1 and s2. + self.assertDifferentType(s1, s2) diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index 212b03d9658b..9a1371e2d0c8 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -678,6 +678,56 @@ def attr(self): with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"): scripted_mod.ignored_attr + def test_ignoring_module_attributes(self): + """ + Test that module attributes can be ignored. + """ + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: int) -> int: + return sum([a]) + + class ModuleWithIgnoredAttr(torch.nn.Module): + __jit_ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int, b: int): + super().__init__() + self.a = a + self.b = b + self.sub = Sub() + + def forward(self) -> int: + return self.b + + @torch.jit.ignore + def ignored_fn(self) -> int: + return self.sub.forward(self.a) + + mod = ModuleWithIgnoredAttr(1, 4) + scripted_mod = torch.jit.script(mod) + self.assertEqual(scripted_mod(), 4) + self.assertEqual(scripted_mod.ignored_fn(), 1) + + # Test the error message for ignored attributes. 
+ class ModuleUsesIgnoredAttr(torch.nn.Module): + __jit_ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int): + super().__init__() + self.a = a + self.sub = Sub() + + def forward(self) -> int: + return self.sub(self.b) + + mod = ModuleUsesIgnoredAttr(1) + + with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"): + scripted_mod = torch.jit.script(mod) + + def test_export_opnames_interface(self): global OneTwoModule diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9ccc5f7cb899..f1e96e31d994 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -284,6 +284,8 @@ class ConcreteModuleTypeBuilder: def add_builtin_function(self, name: str, symbol_name: str): ... def add_failed_attribute(self, name: str, failure_reason: str): ... def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ... + def add_ignored_attribute(self, name: str): ... + def add_ignored_attributes(self, names: List[str]): ... class ConcreteModuleType: def get_constants(self) -> Dict[str, Any]: ... diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index 169b589cbfff..058d3bd58a05 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -106,6 +106,7 @@ bool ConcreteModuleTypeBuilder::equals( bool equal = pyClass_.is(other.pyClass_) && iterableModuleKind_ == other.iterableModuleKind_ && + ignoredAttributes_ == other.ignoredAttributes_ && constants_ == other.constants_ && attributes_ == other.attributes_ && overloads_ == other.overloads_ && @@ -186,6 +187,10 @@ c10::optional ConcreteModuleType::findFailedAttribute( return c10::nullopt; } +bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { + return data_.ignoredAttributes_.count(name) > 0; +} + std::shared_ptr ConcreteModuleType:: findSubmoduleConcreteType(const std::string& name) const { const auto it = std::find_if( @@ -281,6 +286,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute( failedAttributes_.emplace(std::move(name), std::move(failureReason)); } +void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) { + ignoredAttributes_.emplace(std::move(name)); +} + void ConcreteModuleType::dump() const { std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n"; diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h index 0410693d439c..1b1a9142f73e 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.h +++ b/torch/csrc/jit/frontend/concrete_module_type.h @@ -81,6 +81,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { std::vector overloadedMethodNames); void addBuiltinFunction(std::string name, std::string symbol_name); void addFailedAttribute(std::string name, std::string failureReason); + void addIgnoredAttribute(std::string name); void setIterableModuleKind(IterableModuleKind kind); // If a ConcreteModuleType is poisoned, it will never compare equal to any @@ -150,6 +151,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { // Any attributes we failed to convert to TorchScript, along with a hint as to // why std::unordered_map failedAttributes_; + // Any attributes that were marked as ignored. They cannot be used in + // TorchScript but can still be used in ignored function in Python. + std::unordered_set ignoredAttributes_; // Any function attributes. 
These are special right now because functions are // not first-class in the type system. std::unordered_map functionAttributes_; @@ -191,6 +195,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType { std::shared_ptr findSubmoduleConcreteType( const std::string& name) const; c10::optional findFailedAttribute(const std::string& name) const; + bool isIgnoredAttribute(const std::string& name) const; // These getters are only here to return things as types that can be // automatically converted by pybind. diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 119b6b5e5de7..a99f706bce68 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -586,6 +586,8 @@ std::shared_ptr ModuleValue::attr( std::string hint; if (auto failureReason = concreteType_->findFailedAttribute(field)) { hint = *failureReason; + } else if (concreteType_->isIgnoredAttribute(field)) { + hint = "attribute was ignored during compilation"; } throw ErrorReport(loc) diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 5e4031fdf435..30a1fce15b0a 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1564,6 +1564,17 @@ void initJitScriptBindings(PyObject* module) { .def( "add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute) + .def( + "add_ignored_attribute", + &ConcreteModuleTypeBuilder::addIgnoredAttribute) + .def( + "add_ignored_attributes", + [](ConcreteModuleTypeBuilder& self, + const std::vector& names) { + for (auto& name : names) { + self.addIgnoredAttribute(name); + } + }) .def( "set_module_dict", [](ConcreteModuleTypeBuilder& self) { @@ -1589,6 +1600,7 @@ void initJitScriptBindings(PyObject* module) { .def("get_attributes", &ConcreteModuleType::getAttributesPy) .def("get_modules", &ConcreteModuleType::getModulesPy) .def("dump", &ConcreteModuleType::dump) + .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute) .def( "equals", [](const ConcreteModuleType& self, const ConcreteModuleType& other) { diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 0eb423516f6f..1237885f3ee4 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -105,6 +105,10 @@ def infer_concrete_type_builder(nn_module, share_types=True): class_annotations = getattr(nn_module, '__annotations__', {}) + # Get user-annotated ignored attributes. 
+ user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) + concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) + # try to infer the type from type annotation or from the object itself def infer_type(name, item): # The forward function from Module is special; never use this annotations; we @@ -123,6 +127,9 @@ def infer_type(name, item): added_names = set() for name, item in nn_module._parameters.items(): + if name in user_annotated_ignored_attributes: + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) # We currently have the invariant in various places in our code @@ -134,12 +141,18 @@ def infer_type(name, item): added_names.add(name) for name, item in nn_module._buffers.items(): + if name in user_annotated_ignored_attributes: + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) concrete_type_builder.add_attribute(name, attr_type, False, True) added_names.add(name) for name, item in nn_module._modules.items(): + if name in user_annotated_ignored_attributes: + continue + attr_type = infer_type(name, item) if item is None: # Modules can be None. We don't have direct support for optional @@ -205,6 +218,9 @@ def infer_type(name, item): # PyTorch adds a few more. Prevent these from getting compiled. continue + if name in user_annotated_ignored_attributes: + continue + if name in added_names: # Don't re-add anything we already added continue @@ -390,7 +406,7 @@ def init_fn(script_module): cpp_module.setattr(name, scripted) script_module._modules[name] = scripted - # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule. + # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. # This ensures we can access these Python methods on the ScriptModule. 
for name in dir(nn_module): item = getattr(nn_module, name, None) @@ -398,6 +414,8 @@ def init_fn(script_module): unbound_function = getattr(type(nn_module), name) bound_method = unbound_function.__get__(script_module) setattr(script_module, name, bound_method) + elif concrete_type.is_ignored_attribute(name): + setattr(script_module, name, item) # For convenience, attach the concrete type to the new ScriptModule script_module._concrete_type = concrete_type From 275bb5e80139efc6f779b09d342b39eacccd97a3 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Tue, 6 Oct 2020 18:07:58 -0700 Subject: [PATCH 17/69] Fix flakiness in caffe2/test:serialization - test_serialization_new_format_old_format_compat (#45915) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45915 Use temp file instead Test Plan: buck test mode/opt-asan //caffe2/test:serialization -- 'test_serialization_new_format_old_format_compat \(test_serialization\.TestBothSerialization\)' --run-disabled --jobs 18 --stress-runs 10 --record-results Reviewed By: malfet Differential Revision: D24142278 fbshipit-source-id: 9c88330fc5664d464daa9124e67644f497353f3b --- test/test_serialization.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 9b30e4690540..5c40c1285c03 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -584,16 +584,21 @@ def __exit__(self, *args, **kwargs): torch.save = self.torch_save class TestBothSerialization(TestCase, SerializationMixin): + @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_new_format_old_format_compat(self): x = [torch.ones(200, 200) for i in range(30)] - torch.save(x, "big_tensor.zip", _use_new_zipfile_serialization=True) - x_new_load = torch.load("big_tensor.zip") - self.assertEqual(x, x_new_load) - - torch.save(x, "big_tensor.zip", _use_new_zipfile_serialization=False) - x_old_load = torch.load("big_tensor.zip") - self.assertEqual(x_old_load, x_new_load) - os.remove("big_tensor.zip") + + def test(filename): + torch.save(x, filename, _use_new_zipfile_serialization=True) + x_new_load = torch.load(filename) + self.assertEqual(x, x_new_load) + + torch.save(x, filename, _use_new_zipfile_serialization=False) + x_old_load = torch.load(filename) + self.assertEqual(x_old_load, x_new_load) + + with tempfile.NamedTemporaryFile() as f: + test(f.name) class TestOldSerialization(TestCase, SerializationMixin): From e8d8de32b427ad3efd39b0c63f44e5a5b8b786d5 Mon Sep 17 00:00:00 2001 From: Hao Lu Date: Tue, 6 Oct 2020 20:52:29 -0700 Subject: [PATCH 18/69] [StaticRuntime] Implement StaticRuntime::benchmark (#45639) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45639 `StaticRuntime::run_individual` is to mimic the caffe2 operator benchmark `SimpleNet::TEST_Benchmark`, so we can accurate information on the operator breakdown. We found that the PyTorch AutogradProfiler adds a lot of overhead to small models, such as the adindexer precomputation_merge net, 100% for batch_size 1, 33% for batch_size 20. This implementation adds very little overhead, as shown in the test plan. Test Plan: Test results are fb internal only. 
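To make the new bindings concrete, here is a rough usage sketch patterned after the `StaticRuntime` test wrapper added in this patch. `TinyMLP` and its shapes are placeholders; `torch._C._jit_to_static_runtime`, `benchmark`, and `benchmark_individual_ops` (with its `IndividualMetrics` fields) are the pieces this diff introduces, and obtaining the C++ module via `scripted._c` follows the test helper rather than a stable public API.

    import torch

    class TinyMLP(torch.nn.Module):  # placeholder model
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(64, 64)

        def forward(self, x):
            return torch.relu(self.fc(x))

    scripted = torch.jit.script(TinyMLP().eval())
    runtime = torch._C._jit_to_static_runtime(scripted._c)

    args, kwargs = [torch.randn(8, 64)], {}
    # Whole-model timing plus a per-node breakdown printed to stdout.
    runtime.benchmark(args, kwargs, 10, 100)
    # Same measurement, returned as IndividualMetrics for programmatic use.
    metrics = runtime.benchmark_individual_ops(args, kwargs, 10, 100)
    print(metrics.setup_time, metrics.total_time)
    print(metrics.time_per_node_type)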
Reviewed By: yinghai, dzhulgakov Differential Revision: D24012088 fbshipit-source-id: f32eb420aace93e2de421a15e4209fce6a3d90f0 --- test/test_static_runtime.py | 43 ++++++++- torch/csrc/jit/runtime/static/impl.cpp | 117 +++++++++++++++++++++++++ torch/csrc/jit/runtime/static/impl.h | 27 ++++++ torch/csrc/jit/runtime/static/init.cpp | 58 +++++++++++- torch/csrc/jit/runtime/static/ops.cpp | 16 ++-- 5 files changed, 247 insertions(+), 14 deletions(-) diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 86dafa3903dd..38feea5f4503 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -1,7 +1,6 @@ +import numpy as np import torch from torch import nn -import numpy as np - from torch.testing._internal.common_utils import TestCase, run_tests @@ -13,8 +12,19 @@ def __init__(self, scripted): else: self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph) - def __call__(self, *inps): - return self.static_runtime.run(inps) + def __call__(self, *args, **kwargs): + if not kwargs: + return self.static_runtime.run(args) + else: + return self.static_runtime.run(args, kwargs) + + def benchmark(self, args, kwargs, warmup_runs, main_runs): + self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs) + + def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs): + return self.static_runtime.benchmark_individual_ops( + args, kwargs, warmup_runs, main_runs + ) def linear_shim(input, weight, bias=None): @@ -117,8 +127,33 @@ def test_multihead_attention_layer(self): attention_a = StaticRuntime(attention) o_test = attention_a(src, src, src, src_mask) + o_test_kw = attention_a(src, src, value=src, mask=src_mask) for a, b in zip(o_ref, o_test): torch.testing.assert_allclose(a, b) + for a, b in zip(o_ref, o_test_kw): + torch.testing.assert_allclose(a, b) + + def test_multihead_attention_layer_benchmark(self): + HID_DIM = 256 + QUERY_LEN = 8 + BATCH_SIZE = 128 + LAYERS = 3 + HEADS = 8 + DROPOUT = 0.1 + device = torch.device("cpu") + attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) + with torch.no_grad(): + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) + + attention.eval() + attention = torch.jit.script(attention) + attention_a = StaticRuntime(attention) + + attention_a.benchmark([src, src, src, src_mask], {}, 10, 10) + metrics = attention_a.benchmark_individual_ops( + [src, src, src, src_mask], {}, 10, 10 + ) def test_mlp(self): # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index e63c0be66454..36fe7a1225d6 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -125,6 +126,122 @@ c10::IValue StaticRuntime::run( return workspace_[graph_->outputs().at(0)]; } +void StaticRuntime::benchmark( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const { + float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs); + std::cout << "Static runtime ms per iter: " << time_per_iter + << ". 
Iters per second: " << 1000.0 / time_per_iter << std::endl; + + IndividualMetrics results = + benchmark_individual_ops(args, kwargs, warmup_runs, main_runs); + std::cout << "Setting up took " << results.setup_time << " ms" << std::endl; + + for (size_t i = 0; i < nodes_.size(); i++) { + const Node* node = nodes_[i].get_node(); + std::cout << "Node #" << i << ": " << results.time_per_node[i] + << " ms/iter, "; + node->print(std::cout, 0, nullptr, false); + } + + std::vector> time_per_node_type_vec{ + results.time_per_node_type.begin(), results.time_per_node_type.end()}; + std::sort( + time_per_node_type_vec.begin(), + time_per_node_type_vec.end(), + [](auto& left, auto& right) { return left.second > right.second; }); + + std::cout << "Time per node type:" << std::endl; + for (const auto& p : time_per_node_type_vec) { + const std::string& kind = p.first; + const double ms = p.second; + std::cout << std::setw(15) << ms << " ms. " << std::setw(10) + << results.percent_per_node_type[kind] << "%. " << kind << " (" + << results.instances_per_node_type[kind] << " nodes)" + << std::endl; + } + std::cout << std::setw(15) << results.total_time << " ms. in Total" + << std::endl; +} + +float StaticRuntime::benchmark_model( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const { + TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1); + + for (int i = 0; i < warmup_runs; i++) { + run(args, kwargs); + } + caffe2::Timer timer; + for (int i = 0; i < main_runs; i++) { + run(args, kwargs); + } + float millis = timer.MilliSeconds(); + return millis / static_cast(main_runs); +} + +StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const { + TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1); + + IndividualMetrics results; + results.total_time = 0.0; + results.time_per_node.resize(nodes_.size(), 0); + + // setup time + caffe2::Timer timer; + std::vector stack(args); + if (!kwargs.empty()) { + // This is not ideal + TORCH_INTERNAL_ASSERT( + schema_ != nullptr, + "Schema is not available. 
Consider creating the Static Runtime " + "with StaticRuntime(const torch::jit::Module& m) instead."); + schema_->checkAndNormalizeInputs(stack, kwargs); + } + for (size_t i = 0; i < stack.size(); i++) { + workspace_[graph_->inputs()[i]] = stack[i]; + } + results.setup_time = timer.MilliSeconds(); + + // warmup runs + for (int i = 0; i < warmup_runs; i++) { + run(args, kwargs); + } + + // main runs + for (int i = 0; i < main_runs; i++) { + for (size_t j = 0; j < nodes_.size(); j++) { + timer.Start(); + nodes_[j].run(workspace_); + float millis = timer.MilliSeconds(); + results.time_per_node[j] += millis; + } + } + + // post processing + for (size_t i = 0; i < nodes_.size(); i++) { + const Node* node = nodes_[i].get_node(); + std::string kind = std::string(node->kind().toQualString()); + results.time_per_node[i] /= static_cast(main_runs); + results.time_per_node_type[kind] += results.time_per_node[i]; + results.instances_per_node_type[kind]++; + results.total_time += results.time_per_node[i]; + } + for (const auto& p : results.time_per_node_type) { + const std::string& kind = p.first; + results.percent_per_node_type[kind] = p.second / results.total_time * 100; + } + return results; +} + ProcessedNode::ProcessedNode(Node* node) : node_(node) { if (node->kind() != prim::ListConstruct && node->kind() != prim::TupleConstruct && diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 270251bc265d..2703da4cf122 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -34,6 +34,33 @@ class TORCH_API StaticRuntime { const std::vector& args, const std::unordered_map& kwargs) const; + void benchmark( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const; + + float benchmark_model( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const; + + struct IndividualMetrics { + float setup_time; + float total_time; + std::vector time_per_node; + std::unordered_map time_per_node_type; + std::unordered_map percent_per_node_type; + std::unordered_map instances_per_node_type; + }; + + IndividualMetrics benchmark_individual_ops( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) const; + #ifdef FBCODE_CAFFE2 using ConstantMap = folly::F14FastMap; #else diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp index f55ca1a2801b..5e404c181275 100644 --- a/torch/csrc/jit/runtime/static/init.cpp +++ b/torch/csrc/jit/runtime/static/init.cpp @@ -6,18 +6,70 @@ namespace jit { void initStaticRuntimeBindings(PyObject* module) { auto m = py::handle(module).cast(); - py::class_(m, "StaticRuntime") + py::class_ static_runtime(m, "StaticRuntime"); + py::class_( + static_runtime, "IndividualMetrics") + .def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time) + .def_readonly("total_time", &StaticRuntime::IndividualMetrics::total_time) + .def_readonly( + "time_per_node", &StaticRuntime::IndividualMetrics::time_per_node) + .def_readonly( + "time_per_node_type", + &StaticRuntime::IndividualMetrics::time_per_node_type) + .def_readonly( + "percent_per_node_type", + &StaticRuntime::IndividualMetrics::percent_per_node_type) + .def_readonly( + "instances_per_node_type", + &StaticRuntime::IndividualMetrics::instances_per_node_type); + static_runtime .def( "run", py::overload_cast&>( - &StaticRuntime::run, py::const_)); + 
&StaticRuntime::run, py::const_)) + .def( + "run", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + c10::IValue ret = self.run(arg_ivalues, kwarg_ivalues); + return toPyObject(ret); + }) + .def( + "benchmark", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + self.benchmark(arg_ivalues, kwarg_ivalues, warmup_runs, main_runs); + }) + .def( + "benchmark_individual_ops", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + return self.benchmark_individual_ops( + arg_ivalues, kwarg_ivalues, warmup_runs, main_runs); + }); m.def( "_jit_to_static_runtime", [](const std::shared_ptr& g) { return StaticRuntime(PrepareForStaticRuntime(g)); }) .def("_jit_to_static_runtime", [](const torch::jit::Module& m) { - return StaticRuntime(PrepareForStaticRuntime(m)); + return StaticRuntime(m); }); } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 19c783c8a996..fe91920f3c11 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -6,14 +6,16 @@ namespace torch { namespace jit { bool canRunOutOfPlace(Node* n) { + static std::unordered_set out_of_place_nodes{"aten::add", + "aten::mul", + "aten::addmm" + "aten::bmm", + "aten::sigmoid", + "aten::cat", + "aten::transpose", + "aten::flatten"}; auto str = std::string(n->kind().toQualString()); - if ((str == "aten::add") || (str == "aten::mul") || (str == "aten::addmm") || - (str == "aten::bmm") || (str == "aten::sigmoid") || - (str == "aten::cat") || (str == "aten::transpose") || - (str == "aten::flatten")) { - return true; - } - return false; + return out_of_place_nodes.count(str) > 0; } std::function getOutOfPlaceOperation( From 5c283fa29243e76947c0b7db2fdcff58aa59d2e6 Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Tue, 6 Oct 2020 21:04:22 -0700 Subject: [PATCH 19/69] [quant] Add 4-bit embedding_bag prepack/unpack support using quint4x2 (#45751) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45751 Use the torch.quint4x2 dtype to create 4-bit packed tensors Test Plan: python test/test_quantization.py TestEmbeddingBagOps Imported from OSS Reviewed By: z-a-f Differential Revision: D24120997 fbshipit-source-id: 6aba2985715a346f6894cf43d5794e104a9ab061 --- .../quantized/cpu/qembeddingbag_prepack.cpp | 97 +++++++++++++------ .../quantized/cpu/qembeddingbag_unpack.cpp | 86 +++++++++++----- test/quantization/test_quantized_op.py | 9 +- 3 files changed, 135 insertions(+), 57 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index e94f0be0d802..bf17fe172a76 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -14,7 +14,6 @@ torch::class_ register_embedding_params(); * To prepack the weights we store the scale and bias (where bias is Xmin) * for each row along with the quantized weights. 
*/ -// TODO: Extend this to support 4-bits once 4-bit qtensor support is added. c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( at::Tensor qweight) { static constexpr int64_t version = 1; @@ -22,13 +21,24 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( qweight.dim() == 2, "quantized::embedding_bag_prepack weight tensor rank should be 2"); TORCH_CHECK( - qweight.scalar_type() == c10::kQUInt8, - "qembedding_bag_prepack currently only supports quint8 weights"); + qweight.scalar_type() == c10::kQUInt8 || + qweight.scalar_type() == c10::kQUInt4x2, + "qembedding_bag_prepack currently only supports quint8 and quint4x2 weights"); at::Tensor weight_contig = qweight.contiguous(qweight.suggest_memory_format()); - const uint8_t* weight_data = - reinterpret_cast(weight_contig.data_ptr()); + + int bit_width, scale_bias_bytes; + uint8_t* weight_data = static_cast(weight_contig.data_ptr()); + if (qweight.scalar_type() == c10::kQUInt8) { + bit_width = 8; + scale_bias_bytes = 8; // extra 8 bytes to store FP scale and bias per row. + } else { + bit_width = 4; + scale_bias_bytes = + 4; // extra 4 bytes to store at::Half scale and bias per row. + } + const auto num_elem_per_byte = 8 / bit_width; int64_t embedding_rows = qweight.size(0); int64_t embedding_cols = qweight.size(1); @@ -50,8 +60,9 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( std::vector output_shape = { embedding_rows, - embedding_cols + - 8}; // extra 8 bytes to store FP scale and zero_point per row. + static_cast( + (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte + + scale_bias_bytes)}; // extra bytes to store scale and bias per row. size_t output_columns = output_shape[1]; // Allocate output packed weights. @@ -61,28 +72,46 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( weight_contig.suggest_memory_format()); auto* output_data = output.data_ptr(); - at::parallel_for( - 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { - for (int64_t row = start_idx; row < end_idx; ++row) { - const uint8_t* input_row = weight_data + row * embedding_cols; - std::uint8_t* output_row = output_data + row * output_columns; - float* output_row_scale_bias = - reinterpret_cast(output_row + embedding_cols); - output_row_scale_bias[0] = weight_scales[row]; - output_row_scale_bias[1] = weight_bias[row]; - for (int64_t col = 0; col < embedding_cols; ++col) { - output_row[col] = input_row[col]; + if (bit_width == 8) { + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + const uint8_t* input_row = weight_data + row * embedding_cols; + std::uint8_t* output_row = output_data + row * output_columns; + float* output_row_scale_bias = + reinterpret_cast(output_row + embedding_cols); + output_row_scale_bias[0] = weight_scales[row]; + output_row_scale_bias[1] = weight_bias[row]; + for (int64_t col = 0; col < embedding_cols; ++col) { + output_row[col] = input_row[col]; + } } - } - }); + }); + } else { + // Re-calculate the number of embedding_cols, to account for values packed + // in a byte. 
+ embedding_cols = + (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte; + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + const uint8_t* input_row = weight_data + row * embedding_cols; + std::uint8_t* output_row = output_data + row * output_columns; + at::Half* output_row_scale_bias = + reinterpret_cast(output_row + embedding_cols); + output_row_scale_bias[0] = weight_scales[row]; + output_row_scale_bias[1] = weight_bias[row]; + for (int64_t col = 0; col < embedding_cols; ++col) { + // The weight values have already been packed, so here we just + // store it in the output tensor. + output_row[col] = input_row[col]; + } + } + }); + } auto packed_ptr = c10::make_intrusive( - output, - weight_scales, - weight_zero_points, - 8 /* bit rate */, - qtype, - version); + output, weight_scales, weight_zero_points, bit_width, qtype, version); return packed_ptr; } @@ -290,13 +319,21 @@ class QEmbeddingPackWeights final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), TORCH_FN(qembeddingbag_byte_prepack)); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), TORCH_FN(qembeddingbag_4bit_prepack)); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), TORCH_FN(qembeddingbag_2bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), + TORCH_FN(qembeddingbag_byte_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), + TORCH_FN(qembeddingbag_4bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), + TORCH_FN(qembeddingbag_2bit_prepack)); } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"), TORCH_FN(QEmbeddingPackWeights::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"), + TORCH_FN(QEmbeddingPackWeights::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp index ca3d9dc71c7e..86c66b64a410 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp @@ -9,38 +9,69 @@ torch::class_ register_embedding_params(); at::Tensor PackedEmbeddingBagWeight::unpack() { auto packed_weight = packed_w; at::Tensor weight_origin; - if (bit_rate_ == 8) { + + if (bit_rate_ == 8 || bit_rate_ == 4) { const auto input_rows = packed_weight.size(0); const auto input_columns = packed_weight.size(1); - - // The last 2 values are used to store the FP32 scale and zero_point values - // per row. - int output_columns = input_columns - 2 * sizeof(float); + int scale_bias_bytes; + const auto num_elem_per_byte = 8 / bit_rate_; + if (bit_rate_ == 8) { + // The last 2 values are used to store the FP32 scale and zero_point + // values per row. + scale_bias_bytes = 8; + } else { + scale_bias_bytes = 4; + } const auto* input = packed_weight.data_ptr(); - std::vector output_shape = {input_rows, output_columns}; + // Calculate the output shape, accounting for the last n bytes to be used + // for scale/bias rest of the entries are packed depending on the bit_width. 
+ std::vector output_shape = { + input_rows, + static_cast(input_columns - scale_bias_bytes) * + num_elem_per_byte}; auto scales = at::from_blob( w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat)); auto zero_points = at::from_blob( w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat)); - weight_origin = at::_empty_per_channel_affine_quantized( - output_shape, - scales.toType(c10::kFloat), - zero_points.toType(c10::kFloat), - 0, // The output channel axis is 0 - device(c10::kCPU).dtype(c10::kQUInt8)); - - uint8_t* output_data = - reinterpret_cast(weight_origin.data_ptr()); - + auto output_columns = output_shape[1]; + uint8_t* output_data; + + // Allocate output weight tensor based on the bit_width + if (bit_rate_ == 8) { + weight_origin = at::_empty_per_channel_affine_quantized( + output_shape, + scales.toType(c10::kFloat), + zero_points.toType(c10::kFloat), + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQUInt8)); + output_data = static_cast(weight_origin.data_ptr()); + } else { + // We create empty qtensor with the full output shape, and dtype set to + // quint4x2 This will internally allocate appropriate storage bytes to + // account for the packed nature of this dtype. + weight_origin = at::_empty_per_channel_affine_quantized( + output_shape, + scales.toType(c10::kFloat), + zero_points.toType(c10::kFloat), + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQUInt4x2)); + output_data = static_cast(weight_origin.data_ptr()); + } + + // Copy over the data from the packed weight to the output. + // For sub-byte tensors this will copy the packed bytes over since the + // sub_byte qtensors are expected to store data in packed format. at::parallel_for(0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) { for (int64_t row = start_idx; row < end_idx; ++row) { const std::uint8_t* input_row = input + row * input_columns; - uint8_t* output_row = output_data + row * output_columns; + uint8_t* output_row = + output_data + row * output_columns / num_elem_per_byte; - for (std::size_t col = 0; col < output_columns; ++col) { + for (std::size_t col = 0; col < output_columns / num_elem_per_byte; + ++col) { output_row[col] = input_row[col]; } // output_columns } @@ -49,7 +80,8 @@ at::Tensor PackedEmbeddingBagWeight::unpack() { return weight_origin; } TORCH_INTERNAL_ASSERT( - "Currently only supporting 8-bit quantization of embedding bag."); + false, + "We currently only support 8-bit and 4-bit quantization of embedding_bag."); return weight_origin; } @@ -171,15 +203,23 @@ class QEmbeddingUnpackWeights final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"), qembeddingbag_byte_unpack); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"), qembeddingbag_4bit_unpack); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"), qembeddingbag_2bit_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"), + qembeddingbag_byte_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"), + qembeddingbag_4bit_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"), + qembeddingbag_2bit_unpack); } TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { // Unpack the packed embedding_bag weights using TorchBind custom class. // TODO extend to support 4-bit qtensor. 
- m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"), TORCH_FN(QEmbeddingUnpackWeights::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"), + TORCH_FN(QEmbeddingUnpackWeights::run)); } } // namespace diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index ceef43dca51c..d11005b51ddd 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -2793,23 +2793,24 @@ class TestQuantizedEmbeddingOps(TestCase): def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams): weights = torch.from_numpy((np.random.random_sample(( num_embeddings, embedding_dim)) + 1).astype(np.float32)) - + qtype = torch.quint8 if bit_rate == 8: w_packed = pack_fn(weights) else: w_packed = pack_fn(weights, optimized_qparams=optimized_qparams) w_unpacked = unpack_fn(w_packed) - if bit_rate == 8: + if bit_rate == 8 or bit_rate == 4: # Check numerics of prepack function that accepts qtensor as input. # We use min-max observer to mimic the quantization performed in the original function. obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) obs(weights) # Get the scale and zero point for the weight tensor qparams = obs.calculate_qparams() - + if bit_rate == 4: + qtype = torch.quint4x2 # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qtype) real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True) unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight) From 11c32611d72a679f95fe1fda8dec04ad37ee0ee4 Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Tue, 6 Oct 2020 21:04:22 -0700 Subject: [PATCH 20/69] [quant] Support 4-bit embedding_bag operators using the dtype quint4x2 (#45752) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45752 Use the torch.quint4x2 dtype to create 4-bit packed tensors in the previous PR. These packed tensors can be directly consumed by the operator. Serialization of the packed tensors is supported using torchbind custom class. Module support will follow in a later PR. 
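A condensed sketch of the end-to-end 4-bit path this patch enables, mirroring the test changes below; the tensor shapes and index values are arbitrary, and the observer/qscheme choice simply follows what the tests use. Each packed row stores two 4-bit values per byte plus a 2-byte fp16 scale and 2-byte fp16 bias at the end.

    import torch
    from torch.quantization.observer import PerChannelMinMaxObserver

    weights = torch.randn(10, 16)  # (num_embeddings, embedding_dim), arbitrary

    # Per-row (channel axis 0) qparams with the float-qparams affine scheme.
    obs = PerChannelMinMaxObserver(dtype=torch.quint4x2,
                                   qscheme=torch.per_channel_affine_float_qparams,
                                   ch_axis=0)
    obs(weights)
    scales, zero_points = obs.calculate_qparams()

    # 4-bit packed qtensor -> TorchBind packed params -> quantized op.
    qweight = torch.quantize_per_channel(weights, scales, zero_points,
                                         axis=0, dtype=torch.quint4x2)
    packed = torch.ops.quantized.embedding_bag_prepack(qweight)

    indices = torch.tensor([0, 1, 2, 5], dtype=torch.int64)
    offsets = torch.tensor([0, 2], dtype=torch.int64)
    out = torch.ops.quantized.embedding_bag_4bit(packed, indices, offsets, mode=0)
    print(out.shape)  # (2, 16): dequantized float32 per-bag sums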
Test Plan: python test/test_quantization.py TestEmbeddingBagOps Imported from OSS Reviewed By: jerryzh168 Differential Revision: D24120996 fbshipit-source-id: 2639353b3343ebc69e058b5ba237d3fc56728e1c --- .../quantized/cpu/embedding_packed_params.h | 8 + .../ATen/native/quantized/cpu/fbgemm_utils.h | 8 + .../native/quantized/cpu/qembeddingbag.cpp | 420 ++++++++++-------- aten/src/ATen/native/quantized/library.cpp | 1 + test/quantization/test_quantized_op.py | 36 +- torch/quantization/observer.py | 7 +- 6 files changed, 280 insertions(+), 200 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h index 921b585b89a7..3327e7d5320f 100644 --- a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h +++ b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h @@ -11,6 +11,14 @@ struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { const c10::optional& per_sample_weights_, bool include_last_offset) = 0; + virtual at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool sparse, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) = 0; + virtual at::Tensor unpack() = 0; virtual int64_t bit_rate() const = 0; diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 983398022353..6d74cc40c215 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -333,4 +333,12 @@ struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { bool sparse, const c10::optional& per_sample_weights_, bool include_last_offset) override; + + at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool sparse, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) override; }; diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index cb82d9aee469..6a1359ee2ad3 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -11,6 +11,193 @@ torch::class_ register_embedding_params(); +namespace { +at::Tensor embedding_bag_4bit_helper( + const at::Tensor& weight, + const at::Tensor& indices, + const c10::optional& offsets_in, + bool sparse, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + TORCH_CHECK( + offsets_in.has_value(), + "embedding_bag_4bit_rowwise_offsets expects offsets to be set"); + + TORCH_CHECK(weight.dim() == 2); + TORCH_CHECK(indices.dim() == 1); + + auto offsets = offsets_in.value(); + TORCH_CHECK(offsets.dim() == 1); + + // FBGEMM expects the offsets to be of int type. + at::Tensor offsets_new = offsets.toType(at::ScalarType::Int); + + auto offsets_data = offsets_new.data_ptr(); + const auto weight_data = weight.data_ptr(); + auto weight_contig = weight.contiguous(); + uint8_t* input_data = weight_contig.data_ptr(); + + // Get compressed indices for sparse op. 
+ int32_t* compressed_indices_mapping_data = nullptr; + int compressed_index_size = 0; + if (sparse) { + compressed_index_size = compressed_indices_mapping.value().numel(); + compressed_indices_mapping_data = + compressed_indices_mapping.value().data_ptr(); + } + + const auto indices_data = indices.data_ptr(); + const int64_t N = weight.size(0); + const int64_t weight_size = weight.size(1); + const int64_t D = + (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset + const int64_t M = offsets.size(0); + + int64_t output_size = M - 1; + std::vector offsets_include_last_val; + if (!include_last_offset) { + output_size = M; + offsets_include_last_val.resize(M + 1); + // Avoid `null pointer passed as argument 2` ASAN violation when ofests + // tensor is empty. + if (M > 0) { + std::memcpy( + offsets_include_last_val.data(), offsets_data, sizeof(int) * M); + } + offsets_include_last_val[M] = indices.numel(); + offsets_data = offsets_include_last_val.data(); + } + + const std::vector shape = {output_size, D}; + auto output = at::empty(shape, weight.options().dtype(at::kFloat)); + auto* output_data = output.data_ptr(); + const int64_t block_size = output.size(1); + TORCH_CHECK(block_size % 2 == 0, "block size must be divisible by 2"); + const int index_size = indices.numel(); + constexpr int prefetch_distance = 16; +#ifdef USE_FBGEMM + if (!sparse) { + // Generate the fbgemm kernel + auto kernel_64_ = fbgemm::GenerateEmbeddingSpMDMNBit( + /*bit rate=*/4, + /*block size=*/block_size, + /*has weights=*/per_sample_weights_.has_value(), + /*normalize_by_lengths=*/false, + /*prefetch distance=*/prefetch_distance, + /*is_weight_positional=*/false, + /*use_offsets=*/true); + + bool success = kernel_64_( + /*output_size=*/output_size, + /*index_size=*/index_size, + /*data_size=*/N, + /*input=*/input_data, + /*indices=*/indices_data, + /*offsets=*/offsets_data, + /*weights=*/ + per_sample_weights_.has_value() + ? per_sample_weights_.value().data_ptr() + : nullptr, + /*output=*/output_data); + + TORCH_CHECK( + success, + "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input"); + } else { + auto kernel_64_ = + fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse( + /*bit rate=*/4, + /*block_size=*/block_size, + /*has weights=*/per_sample_weights_.has_value(), + /*normalize_by_lengths=*/false, + /*prefetch distance*/ prefetch_distance, + /*is_weight_positional*/ false, + /*use_offsets*/ true); + bool success = kernel_64_( + /*output_size=*/output_size, + /*index_size=*/index_size, + /*data_size=*/compressed_index_size, + /*input=*/input_data, + /*indices=*/indices_data, + /*offsets=*/offsets_data, + /*weights=*/ + per_sample_weights_.has_value() + ? 
per_sample_weights_.value().data_ptr() + : nullptr, + /*output=*/output_data, + /*compressed_indices_table=*/compressed_indices_mapping_data); + TORCH_CHECK( + success, + "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input"); + } +#else + + auto accessor = offsets.accessor(); + std::vector lengths_data; + + int64_t lower = accessor[0]; + for (int64_t i = 1; i < offsets.numel(); ++i) { + lengths_data.push_back(accessor[i] - lower); + lower = accessor[i]; + } + if (!include_last_offset) { + lengths_data.push_back(indices.numel() - lower); + } + + int64_t current = 0; + float* per_sample_weights_data; + if (per_sample_weights_.has_value()) { + per_sample_weights_data = per_sample_weights_.value().data_ptr(); + } + for (int m = 0; m < output_size; ++m) { + memset(output_data, 0, block_size * sizeof(float)); + TORCH_CHECK( + current + lengths_data[m] <= index_size, + "Expect the lengths data to be less than indices size"); + + for (int i = 0; i < lengths_data[m]; ++i, ++current) { + int64_t idx; + if (!sparse) { + idx = indices_data[current]; + TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data"); + } else { + int64_t uncompressed_idx = indices_data[current]; + TORCH_CHECK( + uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size, + "Invalid indices data for Sparse Op.") + idx = compressed_indices_mapping_data[uncompressed_idx]; + if (idx == -1) { + continue; + } + } + const at::Half* scale_bias = reinterpret_cast( + input_data + (idx + 1) * weight_size - 2 * sizeof(at::Half)); + + float weight_val = 1.0f; + if (per_sample_weights_.has_value()) { + weight_val = per_sample_weights_data[current]; + } + const float scale = weight_val * scale_bias[0]; + const float bias = weight_val * scale_bias[1]; + + for (int j = 0; j < block_size; ++j) { + uint8_t quantized = + input_data[idx * weight_size + j / /*NUM_ELEM_PER_BYTE*/ 2]; + quantized >>= (j % 2) * 4; + quantized &= (1 << 4) - 1; + + output_data[j] = fma(scale, quantized, output_data[j] + bias); + } + } // for each i + output_data += block_size; + } // for each m + +#endif + return output; +} +} // namespace + at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( const at::Tensor& indices, const c10::optional& offsets_in, @@ -109,6 +296,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( return output; } +at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets_in, + bool sparse, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + return embedding_bag_4bit_helper( + packed_w, + indices, + offsets_in, + sparse, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); +} + namespace at { namespace native { namespace { @@ -123,7 +327,7 @@ Tensor embedding_bag_byte_rowwise_offsets( const c10::optional& per_sample_weights_, bool include_last_offset) { TORCH_CHECK(weight.scalar_type() == at::kByte); - TORCH_CHECK(weight.ndimension() == 2); + TORCH_CHECK(weight.dim() == 2); TORCH_CHECK( offsets_in.has_value(), "embedding_bag_byte_rowwise_offsets expects offsets to be set"); @@ -224,185 +428,14 @@ Tensor embedding_bag_4bit_rowwise_offsets( const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_4bit_rowwise_offsets expects offsets to be set"); - - TORCH_CHECK(weight.ndimension() == 2); - TORCH_CHECK(indices.ndimension() == 1); - - 
auto offsets = offsets_in.value(); - TORCH_CHECK(offsets.ndimension() == 1); - - // FBGEMM expects the offsets to be of int type. - at::Tensor offsets_new = offsets.toType(ScalarType::Int); - - auto offsets_data = offsets_new.data_ptr(); - const auto weight_data = weight.data_ptr(); - uint8_t* input_data = nullptr; - if (!weight.is_contiguous()) { - auto weight_contig = weight.contiguous(); - input_data = weight_contig.data_ptr(); - } else { - input_data = weight.data_ptr(); - } - - // Get compressed indices for sparse op. - int32_t* compressed_indices_mapping_data = nullptr; - int compressed_index_size = 0; - if (sparse) { - compressed_index_size = compressed_indices_mapping.value().numel(); - compressed_indices_mapping_data = - compressed_indices_mapping.value().data_ptr(); - } - - const auto indices_data = indices.data_ptr(); - const int64_t N = weight.size(0); - const int64_t D = - (weight.size(1) - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset - const int64_t M = offsets.size(0); - - int64_t output_size = M - 1; - std::vector offsets_include_last_val; - if (!include_last_offset) { - output_size = M; - offsets_include_last_val.resize(M + 1); - // Avoid `null pointer passed as argument 2` ASAN violation when ofests - // tensor is empty. - if (M > 0) { - std::memcpy( - offsets_include_last_val.data(), offsets_data, sizeof(int) * M); - } - offsets_include_last_val[M] = indices.numel(); - offsets_data = offsets_include_last_val.data(); - } - - const std::vector shape = {output_size, D}; - auto output = at::empty(shape, weight.options().dtype(at::kFloat)); - auto* output_data = output.data_ptr(); - const int64_t block_size = output.size(1); - TORCH_CHECK(block_size % 2 == 0, "block size must be divisible by 2"); - const int index_size = indices.numel(); - constexpr int prefetch_distance = 16; -#ifdef USE_FBGEMM - if (!sparse) { - // Generate the fbgemm kernel - auto kernel_64_ = fbgemm::GenerateEmbeddingSpMDMNBit( - /*bit rate=*/4, - /*block size=*/block_size, - /*has weights=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch distance=*/prefetch_distance, - /*is_weight_positional=*/false, - /*use_offsets=*/true); - - bool success = kernel_64_( - /*output_size=*/output_size, - /*index_size=*/index_size, - /*data_size=*/N, - /*input=*/input_data, - /*indices=*/indices_data, - /*offsets=*/offsets_data, - /*weights=*/ - per_sample_weights_.has_value() - ? per_sample_weights_.value().data_ptr() - : nullptr, - /*output=*/output_data); - - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input"); - } else { - auto kernel_64_ = - fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse( - /*bit rate=*/4, - /*block_size=*/block_size, - /*has weights=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch distance*/ prefetch_distance, - /*is_weight_positional*/ false, - /*use_offsets*/ true); - bool success = kernel_64_( - /*output_size=*/output_size, - /*index_size=*/index_size, - /*data_size=*/compressed_index_size, - /*input=*/weight_data, - /*indices=*/indices_data, - /*offsets=*/offsets_data, - /*weights=*/ - per_sample_weights_.has_value() - ? 
per_sample_weights_.value().data_ptr() - : nullptr, - /*output=*/output_data, - /*compressed_indices_table=*/compressed_indices_mapping_data); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input"); - } -#else - - auto accessor = offsets.accessor(); - std::vector lengths_data; - - int64_t lower = accessor[0]; - for (int64_t i = 1; i < offsets.numel(); ++i) { - lengths_data.push_back(accessor[i] - lower); - lower = accessor[i]; - } - if (!include_last_offset) { - lengths_data.push_back(indices.numel() - lower); - } - - int64_t current = 0; - float* per_sample_weights_data; - if (per_sample_weights_.has_value()) { - per_sample_weights_data = per_sample_weights_.value().data_ptr(); - } - for (int m = 0; m < output_size; ++m) { - memset(output_data, 0, block_size * sizeof(float)); - TORCH_CHECK( - current + lengths_data[m] <= index_size, - "Expect the lengths data to be less than indices size"); - - for (int i = 0; i < lengths_data[m]; ++i, ++current) { - int64_t idx; - if (!sparse) { - idx = indices_data[current]; - TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data"); - } else { - int64_t uncompressed_idx = indices_data[current]; - TORCH_CHECK( - uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size, - "Invalid indices data for Sparse Op.") - idx = compressed_indices_mapping_data[uncompressed_idx]; - if (idx == -1) { - continue; - } - } - const at::Half* scale_bias = reinterpret_cast( - input_data + (idx + 1) * weight.size(1) - 2 * sizeof(at::Half)); - - float weight_val = 1.0f; - if (per_sample_weights_.has_value()) { - weight_val = per_sample_weights_data[current]; - } - const float scale = weight_val * scale_bias[0]; - const float bias = weight_val * scale_bias[1]; - - for (int j = 0; j < block_size; ++j) { - uint8_t quantized = - input_data[idx * weight.size(1) + j / /*NUM_ELEM_PER_BYTE*/ 2]; - quantized >>= (j % 2) * 4; - quantized &= (1 << 4) - 1; - - output_data[j] = fma(scale, quantized, output_data[j] + bias); - } - } // for each i - output_data += block_size; - } // for each m - -#endif - return output; + return embedding_bag_4bit_helper( + weight, + indices, + offsets_in, + sparse, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); } template @@ -421,6 +454,14 @@ class QEmbeddingBag final { if (bit_rate == 8) { return packed_weight->embeddingbag_byte( indices, offsets, sparse, per_sample_weights_, include_last_offset); + } else if (bit_rate == 4) { + return packed_weight->embeddingbag_4bit( + indices, + offsets, + sparse, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); } else { TORCH_INTERNAL_ASSERT( "Currently only support 8-bit embedding_bag quantization"); @@ -451,12 +492,23 @@ class QEmbedding final { TORCH_LIBRARY_IMPL(quantized, CPU, m) { // Function that works on TorchBind packed weights. - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte"), TORCH_FN(QEmbeddingBag<8>::run)); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_byte"), TORCH_FN(QEmbedding<8>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte"), + TORCH_FN(QEmbeddingBag<8>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit"), + TORCH_FN(QEmbeddingBag<4>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_byte"), + TORCH_FN(QEmbedding<8>::run)); // Functions that work on at::Tensor packed weight. 
- m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"), embedding_bag_byte_rowwise_offsets); - m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"), embedding_bag_4bit_rowwise_offsets); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"), + embedding_bag_byte_rowwise_offsets); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"), + embedding_bag_4bit_rowwise_offsets); } } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index dceb06b05d4a..39e7b03c140a 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -114,6 +114,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool sparse=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor")); diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index d11005b51ddd..8d0a9d7063cc 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -2964,22 +2964,30 @@ def get_reference_result( torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol) - if bit_rate == 8: + + if bit_rate == 8 or bit_rate == 4: # Test operator that accepts TorchBind packed weights. 
- obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + if bit_rate == 4: + qdtype = torch.quint4x2 + op = torch.ops.quantized.embedding_bag_4bit + else: + qdtype = torch.quint8 + op = torch.ops.quantized.embedding_bag_byte + obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) obs(weights) # Get the scale and zero point for the weight tensor qparams = obs.calculate_qparams() # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype) packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) - result = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, offsets, mode=0, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset) + result = op(packed_weight, indices, offsets, mode=0, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset) torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol) + """ Tests the correctness of the embedding_bag_8bit quantized operator """ @skipIfNoFBGEMM @given(num_embeddings=st.integers(10, 100), @@ -2987,10 +2995,10 @@ def get_reference_result( num_offsets=st.integers(1, 20), enable_per_sample_weights=st.booleans(), include_last_offset=st.booleans()) - def test_embedding_bag_byte_rowwise_offsets(self, num_embeddings, - embedding_dim, num_offsets, - enable_per_sample_weights, - include_last_offset): + def test_embedding_bag_byte(self, num_embeddings, + embedding_dim, num_offsets, + enable_per_sample_weights, + include_last_offset): self.embedding_bag_rowwise_offsets_run( 8, num_embeddings, embedding_dim, num_offsets, enable_per_sample_weights, include_last_offset, @@ -3002,10 +3010,10 @@ def test_embedding_bag_byte_rowwise_offsets(self, num_embeddings, num_offsets=st.integers(1, 20), enable_per_sample_weights=st.booleans(), include_last_offset=st.booleans()) - def test_embedding_bag_4bit_rowwise_offsets(self, num_embeddings, - embedding_dim, num_offsets, - enable_per_sample_weights, - include_last_offset): + def test_embedding_bag_4bit(self, num_embeddings, + embedding_dim, num_offsets, + enable_per_sample_weights, + include_last_offset): self.embedding_bag_rowwise_offsets_run(4, num_embeddings, embedding_dim, num_offsets, enable_per_sample_weights, diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 5c8257d213e1..4dac8fb68429 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -136,7 +136,8 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, assert self.dtype in ( torch.qint8, torch.quint8, - ), "Default Observer only works for qint8 and quint8 data type" + torch.quint4x2, + ), "Default Observer only works for qint8, quint8 and quint4x2 data type" self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) if self.has_customized_qrange: self._validate_qmin_qmax(quant_min, quant_max) @@ -208,11 +209,13 @@ def _calculate_qmin_qmax(self) -> Tuple[int, int]: quant_min, quant_max = -64, 63 else: quant_min, quant_max = -128, 127 - else: + elif self.dtype == torch.quint8: if self.reduce_range: quant_min, quant_max = 0, 127 else: quant_min, quant_max = 0, 255 + else: + quant_min, quant_max = 0, 15 return quant_min, quant_max @torch.jit.export From 43dc7ef9335158fbdb124e5fc0952789e528d06e Mon Sep 17 
00:00:00 2001 From: Supriya Rao Date: Tue, 6 Oct 2020 21:04:22 -0700 Subject: [PATCH 21/69] [quant] Support for 4-bit quantized EmbeddingBag module (#45865) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45865 Test Plan: python test/test_quantization.py TestPostTrainingStatic.test_quantized_embedding_bag python test/test_quantization.py TestStaticQuantizedModule.test_embedding_bag_api Imported from OSS Reviewed By: jerryzh168 Differential Revision: D24120995 fbshipit-source-id: c55fc6b2cfd683d14d2a05be7c04f787fdf8cc79 --- test/quantization/test_quantize.py | 66 +++++++++++-------- test/quantization/test_quantized_module.py | 56 +++++++++------- torch/nn/quantized/modules/embedding_ops.py | 33 ++++++---- torch/nn/quantized/modules/utils.py | 2 +- .../testing/_internal/common_quantization.py | 15 +++-- 5 files changed, 102 insertions(+), 70 deletions(-) diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index fb2f57282d79..d2e8b64f4601 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -25,6 +25,9 @@ float_qparams_dynamic_qconfig, register_observed_custom_module_mapping, register_quantized_custom_module_mapping, + PerChannelMinMaxObserver, + QConfigDynamic, + default_dynamic_quant_observer ) from torch.testing._internal.common_quantization import ( @@ -538,42 +541,49 @@ def test_quantized_embedding_bag(self): r""" Test the post-training quantization flow, serialization and scripting of embedding_bag modules """ - model = EmbeddingBagModule().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) offsets = torch.tensor([0, 19, 20, 28, 28, 32]) weights = torch.randn(10, 12, dtype=torch.float32) - model.qconfig = float_qparams_dynamic_qconfig - prepare(model, inplace=True) - quantized_model = convert(model) - - per_sample_weights = torch.from_numpy(np.random.uniform( - low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) - - # Test to make sure module is quantized correctly. 
- self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) - self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) - self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True) + for dtype in [torch.quint8, torch.quint4x2]: + model = EmbeddingBagModule().eval() + float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0) + float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, + weight=float_qparams_observer) + model.qconfig = float_qparams_qconfig - class EmbeddingBagWithLinear(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') - self.fc = torch.nn.Linear(5, 5) + prepare(model, inplace=True) + quantized_model = convert(model) - def forward(self, indices, offsets, per_sample_weights, linear_in): - return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in) + per_sample_weights = torch.from_numpy(np.random.uniform( + low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) - # Test quantization of embedding_bag layer only - model = EmbeddingBagWithLinear().eval() - model.emb.qconfig = float_qparams_dynamic_qconfig - prepare(model, inplace=True) - quantized_model = convert(model) + # Test to make sure module is quantized correctly. + self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) + self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) + self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True) - self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) - self.checkLinear(model.fc) - self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) + class EmbeddingBagWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, + include_last_offset=True, scale_grad_by_freq=False, mode='sum') + self.fc = torch.nn.Linear(5, 5) + + def forward(self, indices, offsets, per_sample_weights, linear_in): + return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in) + + # Test quantization of embedding_bag layer only + model2 = EmbeddingBagWithLinear().eval() + model2.emb.qconfig = float_qparams_qconfig + prepare(model2, inplace=True) + quantized_model = convert(model2) + + self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) + self.checkLinear(model2.fc) + self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) @skipIfNoFBGEMM def test_custom_module_class(self): diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index fabee5615cc8..f5c3a8e3e8d5 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -7,7 +7,8 @@ import torch.quantization from torch.quantization import ( - default_float_qparams_observer + default_float_qparams_observer, + PerChannelMinMaxObserver ) from torch.testing._internal.common_quantization import ( QuantizationTestCase, @@ -742,7 +743,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): w_packed = qemb._packed_params._packed_weight module_out = qemb(indices) - # Call the qembedding_bag 
operator directly + # Call the qembedding operator directly ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False) self.assertEqual(module_out, ref) self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False) @@ -758,6 +759,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig): r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8 """ + num_lengths = np.random.randint(1, 6) lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32) num_indices = np.sum(lengths) @@ -768,28 +770,36 @@ def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0) weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)) - obs = default_float_qparams_observer() - obs(weights) - # Get the scale and zero point for the weight tensor - qparams = obs.calculate_qparams() - # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) - qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum', _weight=qweight) - qemb(indices, offsets) - - # Ensure the module has the correct weights - self.assertEqual(qweight, qemb.weight()) - - w_packed = qemb._packed_params._packed_weight - module_out = qemb(indices, offsets) + for qdtype in [torch.quint8, torch.quint4x2]: + obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + # Get the scale and zero point for the weight tensor + qparams = obs.calculate_qparams() + # Quantize the weights to 8bits + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype) + qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, + include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype) + qemb(indices, offsets) + + # Ensure the module has the correct weights + self.assertEqual(qweight, qemb.weight()) + + w_packed = qemb._packed_params._packed_weight + module_out = qemb(indices, offsets) + + # Call the qembedding_bag operator directly + if qdtype == torch.quint8: + ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0, + per_sample_weights=None, + include_last_offset=True) + else: + ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0, + per_sample_weights=None, + include_last_offset=True) - # Call the qembedding_bag operator directly - ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0, - per_sample_weights=None, - include_last_offset=True) - self.assertEqual(module_out, ref) - self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True) + self.assertEqual(module_out, ref) + self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, + offsets, set_qconfig, is_emb_bag=True, dtype=qdtype) class TestDynamicQuantizedModule(QuantizationTestCase): @given( diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index 278eeed2ca9f..8c660fcb73a0 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ 
b/torch/nn/quantized/modules/embedding_ops.py @@ -11,31 +11,31 @@ class EmbeddingPackedParams(torch.nn.Module): def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): super(EmbeddingPackedParams, self).__init__() self.dtype = dtype - if self.dtype == torch.quint8: + if self.dtype in [torch.quint8, torch.quint4x2]: scales = torch.ones(num_embeddings, dtype=torch.float) zero_points = torch.zeros(num_embeddings, dtype=torch.float) wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales, zero_points=zero_points, - axis=0, dtype=torch.quint8) + axis=0, dtype=self.dtype) self.set_weight(wq) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.') @torch.jit.export def set_weight(self, weight): # type: (torch.Tensor) -> None - if self.dtype == torch.quint8: + if self.dtype in [torch.quint8, torch.quint4x2]: self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.') @torch.jit.export def _weight(self): - if self.dtype == torch.quint8: + if self.dtype in [torch.quint8, torch.quint4x2]: return torch.ops.quantized.embedding_bag_unpack(self._packed_weight) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.') def forward(self, x): return x @@ -192,17 +192,23 @@ def __init__(self, num_embeddings: int, embedding_dim: int, max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None, include_last_offset: bool = False, dtype=torch.quint8) -> None: - super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight) + super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) self.mode = mode self.sparse = sparse self.include_last_offset = include_last_offset + self.dtype = dtype def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None, compressed_indices_mapping: Optional[Tensor] = None) -> Tensor: - return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0, - self.sparse, per_sample_weights, compressed_indices_mapping, - self.include_last_offset) + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0, + self.sparse, per_sample_weights, compressed_indices_mapping, + self.include_last_offset) + else: + return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0, + self.sparse, per_sample_weights, compressed_indices_mapping, + self.include_last_offset) def _get_name(self): return 'QuantizedEmbeddingBag' @@ -226,13 +232,14 @@ def from_float(cls, mod): dtype = weight_observer.dtype - assert dtype == torch.quint8, 'The only supported dtype for nnq.EmbeddingBag is torch.quint8' + assert dtype == torch.quint8 or dtype == torch.quint4x2, \ + 'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2' # Run the observer to calculate qparams. 
weight_observer(mod.weight) qweight = _quantize_weight(mod.weight.float(), weight_observer) # Create quantized EmbeddingBag module and pass in the quantized weight - qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim) + qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype) qembedding_bag.set_weight(qweight) return qembedding_bag diff --git a/torch/nn/quantized/modules/utils.py b/torch/nn/quantized/modules/utils.py index d531983a6ff5..9d6e93f9d2fe 100644 --- a/torch/nn/quantized/modules/utils.py +++ b/torch/nn/quantized/modules/utils.py @@ -17,7 +17,7 @@ def _quantize_weight(float_wt, observer): elif observer.qscheme in [torch.per_channel_affine_float_qparams]: qweight = torch.quantize_per_channel( float_wt, - wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, torch.quint8) + wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, observer.dtype) else: raise ValueError("Unexpected qscheme " + observer.qscheme) return qweight diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 468fd9cfdc81..73efb5181db7 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -12,7 +12,7 @@ from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \ - get_default_qat_qconfig + get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic from torch.quantization import ( is_custom_module_class, is_observed_custom_module, @@ -667,7 +667,8 @@ def checkGraphModeFxOp(self, model, inputs, quant_type, qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) - def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag): + def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, + set_qconfig, is_emb_bag, dtype=torch.quint8): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] @@ -690,9 +691,9 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic # Check state dict serialization and torch.save APIs if is_emb_bag: loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum') + include_last_offset=True, mode='sum', dtype=dtype) else: - loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) @@ -711,7 +712,11 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) if set_qconfig: - float_embedding.qconfig = float_qparams_dynamic_qconfig + float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0) + float_embedding.qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, + weight=float_qparams_observer) prepare_dynamic(float_embedding) 
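Taken together, the changes in this patch enable a short eager-mode flow for 4-bit embedding-bag quantization. The sketch below is illustrative only and mirrors the test code above: FloatEmbeddingBagModel is a hypothetical wrapper (any float module holding a torch.nn.EmbeddingBag works the same way), and the indices/offsets values are made up.

import torch
import torch.nn as nn
from torch.quantization import (
    prepare, convert, QConfigDynamic,
    PerChannelMinMaxObserver, default_dynamic_quant_observer,
)

# Hypothetical float model wrapping an EmbeddingBag (stand-in for EmbeddingBagModule in the tests).
class FloatEmbeddingBagModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                   include_last_offset=True, mode='sum')

    def forward(self, indices, offsets):
        return self.emb(indices, offsets)

model = FloatEmbeddingBagModel().eval()

# Weight observer producing per-channel float qparams for 4-bit (quint4x2) storage,
# exactly as configured in the tests in this patch.
float_qparams_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.quint4x2,
    qscheme=torch.per_channel_affine_float_qparams,
    ch_axis=0)
model.emb.qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                                   weight=float_qparams_observer)

prepare(model, inplace=True)
quantized_model = convert(model)  # model.emb becomes torch.nn.quantized.EmbeddingBag

# Made-up inputs: with include_last_offset=True the final offset equals indices.numel().
indices = torch.tensor([0, 1, 2, 3, 4, 5])
offsets = torch.tensor([0, 3, 6])
out = quantized_model(indices, offsets)

Under the hood, convert swaps emb for torch.nn.quantized.EmbeddingBag, whose forward dispatches to quantized::embedding_bag_4bit when the packed weight dtype is torch.quint4x2, as added in embedding_ops.py above.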
From 1b31ed3ad60fef347ef8043d72209d370d56251b Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Tue, 6 Oct 2020 21:04:22 -0700 Subject: [PATCH 22/69] [quant] Refactor qembeddingbag to remove duplicate code (#45881) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45881 Test Plan: python test/test_quantization.py TestQuantizedEmbeddingBagOps Imported from OSS Reviewed By: jerryzh168 Differential Revision: D24127892 fbshipit-source-id: 344ee71d335b8c1d668c647db88775632e099dbd --- .../native/quantized/cpu/qembeddingbag.cpp | 118 ++++-------------- 1 file changed, 27 insertions(+), 91 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index 6a1359ee2ad3..96453255112d 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -196,9 +196,9 @@ at::Tensor embedding_bag_4bit_helper( #endif return output; } -} // namespace -at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( +at::Tensor embedding_bag_byte_helper( + const at::Tensor& packed_w, const at::Tensor& indices, const c10::optional& offsets_in, bool sparse, @@ -296,6 +296,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( return output; } +} // namespace + +at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( + const at::Tensor& indices, + const c10::optional& offsets_in, + bool sparse, + const c10::optional& per_sample_weights_, + bool include_last_offset) { + return embedding_bag_byte_helper( + packed_w, + indices, + offsets_in, + sparse, + per_sample_weights_, + include_last_offset); +} + at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( const at::Tensor& indices, const c10::optional& offsets_in, @@ -323,99 +340,18 @@ Tensor embedding_bag_byte_rowwise_offsets( const c10::optional& offsets_in, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool /* sparse */, + bool sparse, const c10::optional& per_sample_weights_, bool include_last_offset) { TORCH_CHECK(weight.scalar_type() == at::kByte); TORCH_CHECK(weight.dim() == 2); - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_byte_rowwise_offsets expects offsets to be set"); - - auto offsets = offsets_in.value(); - auto offsets_data = offsets.data_ptr(); - const auto weight_data = weight.data_ptr(); - const auto indices_data = indices.data_ptr(); - - const int64_t N = weight.size(0); - const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias - const int64_t M = offsets.size(0); - - int64_t output_size = M - 1; - std::vector offsets_include_last; - - if (!include_last_offset) { - output_size = M; - offsets_include_last.resize(M + 1); - std::memcpy( - offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * M); - offsets_include_last[M] = indices.numel(); - offsets_data = offsets_include_last.data(); - } - - std::vector shape = {output_size, D}; - auto output = at::empty(shape, weight.options().dtype(at::kFloat)); - auto* output_data = output.data_ptr(); - -#ifdef USE_FBGEMM - - auto kernel_i8_i64 = - fbgemm::GenerateEmbeddingSpMDM( - /*block_size=*/D, - /*has_weight=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) - /*is_weight_positional=*/false, - /*use_offsets=*/true); - - if (weight.is_contiguous()) { - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, 
- /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } else { - auto weight_contig = weight.contiguous(); - const auto weight_data_contig = weight_contig.data_ptr(); - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data_contig, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } -#endif - return output; + return embedding_bag_byte_helper( + weight, + indices, + offsets_in, + sparse, + per_sample_weights_, + include_last_offset); } Tensor embedding_bag_4bit_rowwise_offsets( From 8b39498a23e1840658cfb3885afd1f39bb7937f9 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 6 Oct 2020 21:51:12 -0700 Subject: [PATCH 23/69] codegen: Allow string arguments to have defaults (#45665) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45665 Fixes #43944 Note that the codegen doesn't use a proper parser so, in the same way as with lists, the string `, ` cannot appear in defaults or it will be interpreted as a splitting point between arguments. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D24141835 Pulled By: ezyang fbshipit-source-id: 578127861fd2504917f4486c44100491a2c40343 --- aten/src/ATen/native/TestOps.cpp | 8 +++ aten/src/ATen/native/native_functions.yaml | 5 ++ test/test_native_functions.py | 16 ++++++ tools/autograd/gen_python_functions.py | 3 +- tools/codegen/api/cpp.py | 20 ++++++++ tools/codegen/gen.py | 24 +++++++-- torch/csrc/utils/python_arg_parser.cpp | 60 +++++++++++++++++++++- torch/csrc/utils/python_arg_parser.h | 8 ++- 8 files changed, 136 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index c89a7ee02221..0ebdce6795aa 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -42,5 +42,13 @@ Tensor _test_optional_floatlist( return output; } +// Test default strings can handle escape sequences properly (although commas are broken) +Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) { + const c10::string_view expect = "\"'\\"; + TORCH_CHECK(a == expect, "Default A failed"); + TORCH_CHECK(b == expect, "Default B failed"); + return dummy; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index c27cb4083ac2..e64a66a07417 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8308,3 +8308,8 @@ python_module: nn dispatch: CPU: _test_optional_floatlist + +# Note: this function is only for testing. 
+- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor + use_c10_dispatcher: full + python_module: nn diff --git a/test/test_native_functions.py b/test/test_native_functions.py index 869c7aad47fb..57d9f89dc341 100644 --- a/test/test_native_functions.py +++ b/test/test_native_functions.py @@ -176,6 +176,22 @@ def fake_module(values, const): self.do_test_optional_filled_intlist_with_module(fake_module) + def test_string_defaults(self): + dummy = torch.rand(1) + fn = torch._C._nn._test_string_default + fn(dummy) + + with self.assertRaisesRegex(RuntimeError, "A"): + fn(dummy, a="") + + with self.assertRaisesRegex(RuntimeError, "B"): + fn(dummy, b="") + + def f(x): + torch._C._nn._test_string_default(x) + scripted_fn = torch.jit.script(f) + scripted_fn(dummy) + if __name__ == '__main__': run_tests() diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index eb5de6f75ef5..5a3e99bdf4e4 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -36,6 +36,7 @@ from .utils import write, is_tensor_method from tools.codegen.code_template import CodeTemplate +from tools.codegen.gen import cpp_string # # declarations blocklist @@ -964,7 +965,7 @@ def method_impl(name, declarations, is_python_method, module): dispatch = [] for i, dictionary in enumerate(grouped): signature = dictionary['signature'] - signatures.append('"{}",'.format(signature)) + signatures.append(f'{cpp_string(str(signature))},') overload_index = i if not is_singleton else None dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method)) diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index 538ba3596c7d..566d8f8265a9 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -161,6 +161,26 @@ def returns_type(rs: Sequence[Return]) -> str: def default_expr(d: str, t: Type) -> str: if d == 'None' and str(t) == 'Tensor?': return '{}' + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = '' + i = 1 + while i + 1 < len(d): + if d[i] != '\\': + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i:i + 2] + i += 2 + + return f'"{s}"' return JIT_TO_CPP_DEFAULT.get(d, d) # Convert an argument into its C++ API form diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 48a2b3f56702..0f386d8520f7 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -9,6 +9,7 @@ import argparse import pathlib import functools +import json from tools.codegen.code_template import CodeTemplate from tools.codegen.model import * @@ -124,6 +125,18 @@ def concatMap(func: Callable[[T], Sequence[S]], xs: Sequence[T]) -> Iterator[S]: for r in func(x): yield r +def cpp_string(s: str) -> str: + """Convert a python string into a c++ string literal """ + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + s = s.replace('\a', '\\a') + s = s.replace('\b', '\\b') + s = s.replace('\f', '\\f') + s = s.replace('\n', '\\n') + s = s.replace('\v', '\\v') + s = s.replace('\t', '\\t') + return f'"{s}"' + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # C++ CODE GENERATION @@ -268,7 +281,7 @@ def func(f: NativeFunction) -> Optional[str]: # def registration only happens in TypeDefault def_registration = "" if dispatch is None: - def_registration = f'm.def("{f.func}");\n' + def_registration = 
f'm.def({cpp_string(str(f.func))});\n' impl_registration = "" if not def_only and not f.manual_kernel_registration and (dispatch is not None or f.dispatch is None): @@ -881,9 +894,12 @@ def compute_registration_declarations(f: NativeFunction) -> str: returns_type = dispatcher.returns_type(f.func.returns) args = dispatcher.arguments(f.func) args_str = ', '.join(map(str, args)) - dispatch = f.dispatch is not None - math = dispatch and 'Math' in f.dispatch # type: ignore - return f"""{returns_type} {name}({args_str}); // {{"schema": "aten::{f.func}", "dispatch": "{dispatch}", "math": "{math}"}} + comment_data : Dict[str, str] = { + 'schema': f'aten::{f.func}', + 'dispatch': str(f.dispatch is not None), + 'math': str(f.dispatch is not None and 'Math' in f.dispatch) + } + return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} """ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 81c55b83bf8c..90a0bb1c92f4 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -503,6 +503,62 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 return args; } +// Parse a string literal to remove quotes and escape sequences +static std::string parse_string_literal(c10::string_view str) { + TORCH_CHECK(str.length() >= 2, "String defaults must be quoted"); + + if (str.front() == '"') { + TORCH_CHECK(str.back() == '"', + "Mismatched quotes in string default: ", str); + } else { + TORCH_CHECK(str.front() == '\'' && str.back() == '\'', + "Invalid quotes in string default: ", str) + } + + std::string parsed; + parsed.reserve(str.size()); + for (size_t i = 1; i < str.size() - 1;) { + if (str[i] != '\\') { + parsed.push_back(str[i]); + ++i; + continue; + } + + // Handle escape sequences + TORCH_CHECK(i < str.size() - 2, "String ends with escaped final quote: ", str) + char c = str[i + 1]; + switch (c) { + case '\\': + case '\'': + case '\"': + break; + case 'a': + c = '\a'; + break; + case 'b': + c = '\b'; + break; + case 'f': + c = '\f'; + break; + case 'n': + c = '\n'; + break; + case 'v': + c = '\v'; + break; + case 't': + c = '\t'; + break; + default: + TORCH_CHECK(false, "Unsupported escape sequence in string default: \\", str[i + 1]); + } + parsed.push_back(c); + i += 2; + } + return parsed; +} + void FunctionParameter::set_default_str(const std::string& str) { if (str == "None") { allow_none = true; @@ -558,8 +614,8 @@ void FunctionParameter::set_default_str(const std::string& str) { throw std::runtime_error("invalid device: " + str); } } else if (type_ == ParameterType::STRING) { - if (str != "None" && str != "") { - throw std::runtime_error("invalid default string: " + str); + if (str != "None") { + default_string = parse_string_literal(str); } } } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 0454e7e2af51..928232cd099a 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -188,6 +188,7 @@ struct PythonArgs { inline c10::optional memoryformatOptional(int i); inline at::QScheme toQScheme(int i); inline std::string string(int i); + inline std::string stringWithDefault(int i, const std::string& default_str); inline c10::optional stringOptional(int i); inline PyObject* pyobject(int i); inline int64_t toInt64(int i); @@ -226,6 +227,7 @@ struct FunctionParameter { at::SmallVector numpy_python_names; at::Scalar 
default_scalar; std::vector default_intlist; + std::string default_string; union { bool default_bool; int64_t default_int; @@ -530,7 +532,11 @@ inline at::QScheme PythonArgs::toQScheme(int i) { } inline std::string PythonArgs::string(int i) { - if (!args[i]) return ""; + return stringWithDefault(i, signature.params[i].default_string); +} + +inline std::string PythonArgs::stringWithDefault(int i, const std::string& default_str) { + if (!args[i]) return default_str; return THPUtils_unpackString(args[i]); } From ed1552a48fb68c4a92821c14751fc63e3428800a Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 6 Oct 2020 23:09:36 -0700 Subject: [PATCH 24/69] Add note about in-place weight modification for nn.Embedding (#45595) Summary: Fixes https://github.com/pytorch/pytorch/issues/26596 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45595 Reviewed By: albanD Differential Revision: D24143456 Pulled By: mruberry fbshipit-source-id: a884a32809105ce16959b40ec745ec873b3c8375 --- torch/nn/modules/sparse.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index f063ffa2e8eb..3589d4b815c9 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -50,6 +50,23 @@ class Embedding(Module): output. The gradient for this vector from :class:`~torch.nn.Embedding` is always zero. + .. note:: + When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the + :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be + modified in-place, performing a differentiable operation on ``Embedding.weight`` before + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + :attr:`max_norm` is not ``None``. 
For example:: + + n, d, m = 3, 5, 7 + embedding = nn.Embedding(n, d, max_norm=True) + W = torch.randn((m, d), requires_grad=True) + idx = torch.tensor([1, 2]) + a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable + b = embedding(idx) @ W.t() # modifies weight in-place + out = (a.unsqueeze(0) + b.unsqueeze(1)) + loss = out.sigmoid().prod() + loss.backward() + Examples:: >>> # an Embedding module containing 10 tensors of size 3 From 317b6516bc6aa960a787aa63ce6bb54ccf54af81 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 6 Oct 2020 23:28:20 -0700 Subject: [PATCH 25/69] [quant] Add quantized::sigmoid that take output_scale/output_zero_point as input (#45882) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45882 Same changes as the stack for leaky_relu: https://github.com/pytorch/pytorch/pull/45702 Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D24129113 fbshipit-source-id: a26da33f877d3bdeea1976b69b2bd9369c2bf196 --- .../cpu/kernels/QuantizedOpKernels.cpp | 16 +----- .../ATen/native/quantized/cpu/qsigmoid.cpp | 55 ++++++++++++++++--- .../ATen/native/quantized/cpu/quantized_ops.h | 2 +- aten/src/ATen/native/quantized/library.cpp | 1 + test/quantization/test_quantized_op.py | 24 +++++++- 5 files changed, 73 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index a65e9f00f1d8..c08295b3514c 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -486,7 +486,8 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx, }); } -void qsigmoid_kernel(const Tensor& qx, Tensor& qy) { +void qsigmoid_kernel( + const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) { int64_t zero_point = qx.q_zero_point(); float scale = qx.q_scale(); auto scale_vec = Vec256(scale); @@ -494,19 +495,6 @@ void qsigmoid_kernel(const Tensor& qx, Tensor& qy) { auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg(); AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() { - // Naive implemenentation: uses dequantize/execute/quantize routine - // - Output scale is set to 1.0 / 2^(BIT_NUM) - // - For signed types output zero point is set to 0 - // - For unsigned types output zero point is set to (qmax + qmin) / 2.0 - // See https://stackoverflow.com/a/34448562/3606192 for potential - // optimizations - float output_scale = 0.00390625; // 1.0 / 2^8 - int64_t output_zero_point = 0; - if (SCALAR_TYPE == at::kQInt32) { - output_scale = 2.3283064365386963e-10; // 1.0 / 2^32 - } else if (SCALAR_TYPE == at::kQInt8) { - output_zero_point = -128; - } float inv_output_scale = 1.0 / output_scale; qy = at::_empty_affine_quantized( diff --git a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp index 5c2bcd859bed..33d2041d124f 100644 --- a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @@ -17,15 +17,11 @@ namespace native { DEFINE_DISPATCH(qsigmoid_stub); #ifdef USE_PYTORCH_QNNPACK -// This ALWAYS outputs scale=1.0/256, dtype=quint8 -// The zero_point is 0 for qint32 and quint8, but -128 for qint8. 
-Tensor qnnpack_sigmoid(Tensor input) { +Tensor qnnpack_sigmoid( + Tensor input, double output_scale, int64_t output_zero_point) { TORCH_CHECK(input.ndimension() > 0, "qnnpack_sigmoid(): Got empty input tensor"); Tensor qy; - constexpr float output_scale = 1.0f / 256.0f; - constexpr int32_t output_zero_point = 0; - initQNNPACK(); Tensor input_contig = input.contiguous(input.suggest_memory_format()); @@ -76,17 +72,60 @@ Tensor qnnpack_sigmoid(Tensor input) { "failed to run QNNPACK sigmoid operator"); return qy; } + #endif // USE_PYTORCH_QNNPACK +// This ALWAYS outputs scale=1.0/256, dtype=quint8 +// The zero_point is 0 for qint32 and quint8, but -128 for qint8. Tensor sigmoid_quantized_cpu(const Tensor& qx) { #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) { - return qnnpack_sigmoid(qx); + constexpr double output_scale = 1.0f / 256.0f; + constexpr int64_t output_zero_point = 0; + return qnnpack_sigmoid(qx, output_scale, output_zero_point); } #endif // USE_PYTORCH_QNNPACK Tensor qy; - qsigmoid_stub(qx.device().type(), qx, qy); + AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() { + // Naive implemenentation: uses dequantize/execute/quantize routine + // - Output scale is set to 1.0 / 2^(BIT_NUM) + // - For signed types output zero point is set to 0 + // - For unsigned types output zero point is set to (qmax + qmin) / 2.0 + // See https://stackoverflow.com/a/34448562/3606192 for potential + // optimizations + double output_scale = 0.00390625; // 1.0 / 2^8 + int64_t output_zero_point = 0; + if (SCALAR_TYPE == at::kQInt32) { + output_scale = 2.3283064365386963e-10; // 1.0 / 2^32 + } else if (SCALAR_TYPE == at::kQInt8) { + output_zero_point = -128; + } + qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point); + }); return qy; } + +namespace { + +class QSigmoid final { + public: + static Tensor run(Tensor qx, double output_scale, int64_t output_zero_point) { +#ifdef USE_PYTORCH_QNNPACK + if (at::globalContext().qEngine() == at::QEngine::QNNPACK && + qx.scalar_type() == kQUInt8) { + return qnnpack_sigmoid(qx, output_scale, output_zero_point); + } +#endif // USE_PYTORCH_QNNPACK + Tensor qy; + qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point); + return qy; + } +}; + +TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("quantized::sigmoid"), TORCH_FN(QSigmoid::run)); +} +} // namespace + }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/quantized_ops.h b/aten/src/ATen/native/quantized/cpu/quantized_ops.h index baf522731e6d..e7d3d50f6673 100644 --- a/aten/src/ATen/native/quantized/cpu/quantized_ops.h +++ b/aten/src/ATen/native/quantized/cpu/quantized_ops.h @@ -8,7 +8,7 @@ namespace native { using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, Scalar /*negval_*/); -using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point); using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); using qclamp_fn = void (*)( const at::Tensor& /*qx*/, diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 39e7b03c140a..1ab399da88e6 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -159,6 
+159,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor")); } // According to #33294: The "_" prefix registration will be diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 8d0a9d7063cc..60dd789af367 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -195,7 +195,7 @@ def _test_activation_function(self, X, fn_name, test_configs): dtype=torch_type) if output_is_observed: - extra_kwargs.update({'output_scale': scale, 'output_zero_point': zero_point}) + extra_kwargs.update({'output_scale': output_scale, 'output_zero_point': output_zero_point}) # Finds qY using in-place or non-in-place quantized operators. qY = q_op(qX, **extra_kwargs) @@ -253,7 +253,7 @@ def test_qrelu6(self, X): @override_qengines @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), qparams=hu.qparams())) - def test_qsigmoid(self, X): + def test_sigmoid_non_observed(self, X): sigmoid_test_configs = [ { 'quantized_fn': [ @@ -266,6 +266,26 @@ def test_qsigmoid(self, X): ] self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) + """Tests the correctness of the quantized::sigmoid op.""" + # TODO: enable after observed output is supported in qnnpack + # @override_qengines + @skipIfNoFBGEMM + @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), + qparams=hu.qparams())) + def test_sigmoid(self, X): + sigmoid_test_configs = [ + { + 'quantized_fn': [ + torch.ops.quantized.sigmoid + ], + 'reference_fn': torch.sigmoid, + 'output_range': (0.0, 1.0), + 'change_zero_point': True, + 'output_is_observed': True, + } + ] + self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) + """Tests the correctness of the quantized::hardsigmoid op.""" @override_qengines @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), From 205ab4961232140d3a46a6a3f59b377243bb5407 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Tue, 6 Oct 2020 23:38:19 -0700 Subject: [PATCH 26/69] [packaging] simpler dependency plotting (#45686) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45686 This uses an online graphviz viewer. The code is simpler, and since it embeds all the data in the url you can just click the url from your terminal. 
Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D24059157 Pulled By: zdevito fbshipit-source-id: 94d755cc2986c4226180b09ba36f8d040dda47cc --- test/test_package.py | 4 +- torch/package/exporter.py | 88 ++++++--------------------------------- 2 files changed, 14 insertions(+), 78 deletions(-) diff --git a/test/test_package.py b/test/test_package.py index 37d7b0f385a2..2ce777100484 100644 --- a/test/test_package.py +++ b/test/test_package.py @@ -186,8 +186,8 @@ def test_resnet(self): # check th debug graph has something reasonable: buf = StringIO() - e._write_dep_graph(failing_module='torch', output_file=buf) - self.assertIn('torchvision.models.resnet', buf.getvalue()) + debug_graph = e._write_dep_graph(failing_module='torch') + self.assertIn('torchvision.models.resnet', debug_graph) # we can now load the saved model i = PackageImporter(f1) diff --git a/torch/package/exporter.py b/torch/package/exporter.py index 2055cf945334..772be06c6e7a 100644 --- a/torch/package/exporter.py +++ b/torch/package/exporter.py @@ -13,7 +13,7 @@ from pathlib import Path import linecache import sys -from tempfile import NamedTemporaryFile +from urllib.parse import quote class PackageExporter: """ Exporters allow you to write packages of code, pickled python data, and @@ -168,83 +168,19 @@ def _module_exists(self, module_name: str) -> bool: except Exception: return False - def _write_dep_graph(self, failing_module=None, output_file=None): - depended_on : Dict[str, List[str]] = {} - for f, t in self.debug_deps: - if t not in depended_on: - depended_on[t] = [] - if f not in depended_on: - depended_on[f] = [] - depended_on[t].append(f) - - level : Dict[str, int] = {} - - def visit(x: str): - if x in level: - return level[x] - level[x] = 0 - for e in depended_on[x]: - level[x] = max(level[x], visit(e) + 1) - return level[x] - - for x in depended_on.keys(): - visit(x) - - nodes = [] - node_to_id = {} - n = 0 - for ft in self.debug_deps: - for e in ft: - if e not in node_to_id: - node_to_id[e] = n - extra = '' - if e == failing_module: - extra = ", color: 'red'" - nodes.append(f" {{id: {n}, label: '{e}', level: {level[e]}, shape: 'box'{extra}}},\n") - n += 1 - edges = [] - for f, t in self.debug_deps: - fn, tn = node_to_id[f], node_to_id[t] - edges.append(f" {{from: {fn}, to: {tn}, arrows: 'to'}},\n") - nodes_s, edges_s = ''.join(nodes), ''.join(edges) + def _write_dep_graph(self, failing_module=None): + edges = '\n'.join(f'"{f}" -> "{t}";' for f, t in self.debug_deps) + failing = '' if failing_module is None else f'{failing_module} [color=red];' template = f"""\ - - - - - - -
- - - - +digraph G {{ +rankdir = LR; +node [shape=box]; +{failing} +{edges} +}} """ - if output_file: - output_file.write(template) - return None - - with NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tf: - tf.write(template) - return tf.name + arg = quote(template, safe='') + return f'https://dreampuf.github.io/GraphvizOnline/#{arg}' def _get_source_of_module(self, module: types.ModuleType) -> str: filename = getattr(module, '__file__', None) From 8cdb638c6242e9278a971733ccbac9fe0cdd2117 Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 7 Oct 2020 00:13:34 -0700 Subject: [PATCH 27/69] [FX] Track use nodes in Node (#45775) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45775 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D24091082 Pulled By: jamesr66a fbshipit-source-id: b09bb6ae78436a7722fb135b8ec71464ef9587cd --- test/fx/quantization.py | 2 +- test/test_fx.py | 53 +++++++------- torch/fx/experimental/GraphManipulation.py | 31 +-------- torch/fx/experimental/Partitioner.py | 7 +- torch/fx/graph.py | 20 ++---- torch/fx/node.py | 80 +++++++++++++--------- torch/quantization/fx/pattern_utils.py | 2 +- 7 files changed, 83 insertions(+), 112 deletions(-) diff --git a/test/fx/quantization.py b/test/fx/quantization.py index a2de582937aa..ff6c98ac038b 100644 --- a/test/fx/quantization.py +++ b/test/fx/quantization.py @@ -164,7 +164,7 @@ def matches(modules, node, pattern, max_uses=sys.maxsize): self_match = pattern arg_matches = None - if node.uses > max_uses: + if len(node.users) > max_uses: return False if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): diff --git a/test/test_fx.py b/test/test_fx.py index 1451c5efe5cb..76217fce9e80 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -659,7 +659,7 @@ def forward(self, x): traced = symbolic_trace(st) traced.graph.lint(traced) stringed = str(traced.graph) - for s in ['args', 'kwargs', 'uses']: + for s in ['args', 'kwargs', '#users']: assert s in stringed def test_graph_fns(self): @@ -717,28 +717,6 @@ def forward(self, x): with self.assertRaisesRegex(AssertionError, message): traced(torch.rand(4, 3)) - def test_get_all_users_of(self): - graph : torch.fx.Graph = torch.fx.Graph() - a : torch.fx.Node = graph.create_node('placeholder', 'x') - b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,)) - c : torch.fx.Node = graph.create_node('get_attr', 'y_attr') - d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) - graph.output(d) - linear_mod : torch.nn.Module = torch.nn.Linear(3, 4) - add_param : torch.Tensor = torch.rand(3, 4) - gm : torch.fx.GraphModule = torch.fx.GraphModule( - {'linear_mod': linear_mod, 'y_attr' : add_param}, graph) - expected_uses: Dict[int, List[int]] = { - 0: [1], - 1: [3], - 2: [3], - 3: [4], - 4: [], - } - for i, node in enumerate(graph.nodes): - user_indexes = GraphManipulation.get_all_users_of(gm, i) - assert user_indexes == expected_uses[i] - def test_copy_no_remap(self): traced = symbolic_trace(SimpleTest()) g = traced.graph @@ -913,7 +891,7 @@ def test_erase_node_error(self): for node in traced.graph.nodes: # Test deleting with uses both in another Node and at the output if node.target in [operator.add, torch.relu]: - with self.assertRaisesRegex(RuntimeError, 'but it still had .* uses in the graph!'): + with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): traced.graph.erase_node(node) def test_find_uses(self): @@ -926,11 +904,12 @@ def 
test_find_uses(self): graph.output((y + z + u).node) graph.lint() - uses_of_x = x.node.find_uses() - self.assertEqual(len(uses_of_x), 3) - expected_ops = ['relu', 'add', 'neg'] - for node, expected in zip(uses_of_x, expected_ops): - assert expected in node.name + users_of_x = x.node.users + self.assertEqual(len(users_of_x), 3) + expected_ops = set(['relu', 'add', 'neg']) + for use in users_of_x: + assert any(use.name.startswith(prefix) for prefix in expected_ops) + def test_multi_insert_point(self): graph = torch.fx.Graph() @@ -948,5 +927,21 @@ def test_multi_insert_point(self): for node, expected in zip(graph.nodes, expected_ops): assert expected in node.name + def test_reassign_args_kwargs_uses(self): + graph = torch.fx.Graph() + x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) + z = x + y + zed = z + z + z + graph.output(zed.node) + graph.lint() + + # zed = z + z + z -> zed = z + z + x + zed.node.args = (zed.node.args[0], x.node) + self.assertEqual(x.node.users.keys(), [z.node, zed.node]) + + # z = x + y -> z = y + y + z.node.args = (y.node, y.node) + self.assertEqual(x.node.users.keys(), [zed.node]) + if __name__ == '__main__': run_tests() diff --git a/torch/fx/experimental/GraphManipulation.py b/torch/fx/experimental/GraphManipulation.py index 0c5d18aa4fb2..e83fa866b512 100644 --- a/torch/fx/experimental/GraphManipulation.py +++ b/torch/fx/experimental/GraphManipulation.py @@ -1,37 +1,8 @@ -from typing import Dict, List +from typing import Dict from torch.fx.graph_module import GraphModule -from typing import Any from torch.fx.node import Node, Target, map_arg from torch.fx.graph import Graph - -"""find_use is used to find out if the node is another node's arg or kwargs.""" -def find_use(arg: Any, node: Node) -> bool: - if isinstance(arg, (tuple, list)): - return any(find_use(elem, node) for elem in arg) - elif isinstance(arg, dict): - return any(find_use(v, node) for k, v in arg.items()) - elif isinstance(arg, slice): - return any([find_use(arg.start, node), find_use(arg.stop, node), find_use(arg.step, node)]) - elif isinstance(arg, Node): - return arg is node - else: - return False - -def get_all_users_of(fx_module: GraphModule, index: int) -> List[int]: - """Given the graph(fx_module) and an index, return a list of all node indexes that use this node""" - graph = fx_module.graph - current_node = graph.nodes[index] - user_indexes: List[int] = [] - """if the node A is in node B's args, then B is the user of A - go through all the nodes, if the input node in any node's args, - then that node is the input node's user - """ - for i, n in enumerate(graph.nodes): - if find_use(n.args, current_node) or find_use(n.kwargs, current_node): - user_indexes.append(i) - return user_indexes - def replace_target_nodes_with( fx_module: GraphModule, old_op: str, diff --git a/torch/fx/experimental/Partitioner.py b/torch/fx/experimental/Partitioner.py index 605900cf974b..8384de8707f2 100644 --- a/torch/fx/experimental/Partitioner.py +++ b/torch/fx/experimental/Partitioner.py @@ -1,6 +1,5 @@ from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.experimental import GraphManipulation from typing import Dict, List, Union class DAGNode: @@ -100,12 +99,10 @@ def get_input_nodes(self) -> List[Node]: def get_output_nodes(self) -> List[Node]: """Output nodes are the nodes that without any user inside this partition.""" output_nodes: List[Node] = [] + nodes_set = set(self.nodes) for node in self.nodes: - index = self.graph_module.graph.nodes.index(node) - 
user_indexes = GraphManipulation.get_all_users_of(self.graph_module, index) - user_nodes = {self.graph_module.graph.nodes[i] for i in user_indexes} # check if user nodes has an intersection with self.nodes - if not set(self.nodes).intersection(user_nodes): + if not nodes_set.intersection(node.users): output_nodes.append(node) return output_nodes diff --git a/torch/fx/graph.py b/torch/fx/graph.py index ed7618372b57..600fcb27a850 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -87,12 +87,6 @@ def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argume val_map[node] = self.node_copy(node, lambda n : val_map[n]) return None - def _mark_uses(self, a: Argument): - def add_use(n: Node): - n.uses += 1 - return n - map_arg(a, add_use) - def create_node(self, op: str, target: Target, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, @@ -100,8 +94,6 @@ def create_node(self, op: str, target: Target, assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') args = () if args is None else args kwargs = {} if kwargs is None else kwargs - self._mark_uses(args) - self._mark_uses(kwargs) sanitized_name = self._register_name_used(name) if name is not None else self._name(target) n = Node(self, sanitized_name, op, target, args, kwargs) if self._insert_point is not None: @@ -127,10 +119,11 @@ def move_node_before(self, to_move : Node, before : Node): def erase_node(self, to_erase : Node): """ Erases the node `to_erase` from the `Graph`. Throws an exception if - there are still uses of that node in the `Graph`. + there are still users of that node in the `Graph`. """ - if to_erase.uses > 0: - raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {to_erase.uses} uses in the graph!') + if len(to_erase.users) > 0: + raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' + f'users in the graph: {to_erase.users}!') node_indices = [i for i, n in enumerate(self._nodes) if n == to_erase] for idx in reversed(node_indices): @@ -191,7 +184,6 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb return self.create_node(node.op, node.target, args, kwargs, name) def output(self, result: Argument): - self._mark_uses(result) return self.create_node(op='output', target='output', args=(result,)) def _name(self, target: Target) -> str: @@ -316,11 +308,11 @@ def format_node(n : Node) -> Optional[str]: placeholder_names.append(n.target) return None elif n.op == 'get_attr': - return f'%{n.name} : [uses={n.uses}] = self.{n.target}' + return f'%{n.name} : [#users={len(n.users)}] = self.{n.target}' elif n.op == 'output': return f'return {n.args[0]}' else: - return f'%{n.name} : [uses={n.uses}] = {n.op}[target={n.target}](' \ + return f'%{n.name} : [#users={len(n.users)}] = {n.op}[target={n.target}](' \ f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})' diff --git a/torch/fx/node.py b/torch/fx/node.py index 53abead5f044..458e1d3c66a8 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -1,5 +1,5 @@ # Nodes represent a definition of a value in our graph of operators. 
-from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict +from typing import TYPE_CHECKING, Union, Callable, Any, Set, Tuple, List, Optional, Dict import torch if TYPE_CHECKING: @@ -30,31 +30,47 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, assert isinstance(target, str) self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add - self.args = args - self.kwargs = kwargs - self.uses = 0 + self._args : Tuple[Argument, ...] = () + self._kwargs : Dict[str, Argument] = {} + self.args, self.kwargs = args, kwargs + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo `x + x` + # would appear once here, but represents two uses. + # + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + self.users : Dict['Node', None] = {} - def find_uses(self) -> List['Node']: - """ - Find all nodes that use the value produced by `self`. The complexity of - this function is linear in the number of nodes * number of arguments to - each node. - - Note that len(find_uses()) is not necessarily equal to attribute `uses`. - This node could be used multiple times in the same `Node`. In that case, - the user node would appear once in the return value here, but `uses` would - account for the total number of times this Node is used by the user node. - e.g. a node for `x + x` would have two uses for the `x` node, but the - `x + x` node would appear once in the return from `find_uses` - """ - use_nodes : List[Node] = [] - for node in self.graph._nodes: - def record_use(arg_node : Node) -> None: - if arg_node == self and (len(use_nodes) == 0 or use_nodes[-1] != node): - use_nodes.append(node) - map_arg(node.args, record_use) - map_arg(node.kwargs, record_use) - return use_nodes + @property + def args(self) -> Tuple[Argument, ...]: + return self._args + + @args.setter + def args(self, a : Tuple[Argument, ...]): + self._update_args_kwargs(new_args=a, new_kwargs=self._kwargs) + + @property + def kwargs(self) -> Dict[str, Argument]: + return self._kwargs + + @kwargs.setter + def kwargs(self, k : Dict[str, Argument]): + self._update_args_kwargs(new_args=self._args, new_kwargs=k) + + def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): + old_defs = self._collect_all_defs() + self._args = new_args + self._kwargs = new_kwargs + new_defs = self._collect_all_defs() + for to_remove in old_defs - new_defs: + to_remove.users.pop(self) + for to_add in new_defs - old_defs: + to_add.users.setdefault(self) + + def _collect_all_defs(self) -> Set['Node']: + defs = set() + map_arg(self._args, lambda n: defs.add(n)) + map_arg(self._kwargs, lambda n: defs.add(n)) + return defs def __repr__(self) -> str: return self.name @@ -64,22 +80,22 @@ def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: Replace all uses of `self` in the Graph with the Node `replace_with`. Returns the list of nodes on which this change was made. 
""" - use_nodes : List[Node] = self.find_uses() - for use_node in use_nodes: + to_process = list(self.users) + for use_node in to_process: def maybe_replace_node(n : Node) -> Node: if n == self: - self.uses -= 1 return replace_with else: return n + new_args = map_arg(use_node.args, maybe_replace_node) - assert isinstance(new_args, tuple) - use_node.args = new_args new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) assert isinstance(new_kwargs, dict) - use_node.kwargs = new_kwargs + use_node._update_args_kwargs(new_args, new_kwargs) - return use_nodes + assert len(self.users) == 0 + return to_process def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index fbdccbc5e3e2..2984da0b80fd 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -48,7 +48,7 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): self_match = pattern arg_matches = [] - if node.uses > max_uses: + if len(node.users) > max_uses: return False if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): From be45c3401af8186f97f0e2b269ff3bafaf16157f Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 7 Oct 2020 01:55:17 -0700 Subject: [PATCH 28/69] [JIT] Make objects throw Python AttributeError on nonexistant attr access (#45911) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45911 Test Plan: Imported from OSS Reviewed By: robieta Differential Revision: D24140971 Pulled By: jamesr66a fbshipit-source-id: 046a2cffff898efad5bcc36a41bf992f36f555f9 --- test/jit/test_freezing.py | 4 ++-- test/jit/test_torchbind.py | 9 +++++++++ torch/csrc/Exceptions.cpp | 7 +++++++ torch/csrc/Exceptions.h | 8 ++++++++ torch/csrc/jit/api/object.h | 17 +++++++++++------ torch/csrc/jit/ir/attributes.h | 4 ++-- torch/csrc/jit/ir/ir.h | 6 +++--- torch/csrc/jit/python/script_init.cpp | 22 +++++++++++++++------- 8 files changed, 57 insertions(+), 20 deletions(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 696b97059d19..598f6f435af0 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -625,7 +625,7 @@ def _forward(self, x): self.assertFalse(mf.hasattr('sub')) self.assertFalse(mf.hasattr('a')) self.assertTrue(mf.hasattr('b')) - with self.assertRaisesRegex(RuntimeError, "TestModule does not have a field with name '_forward'"): + with self.assertRaisesRegex(AttributeError, "TestModule does not have a field with name '_forward'"): mf._forward(x) def test_freeze_module_with_inplace_mutable(self): @@ -1047,7 +1047,7 @@ def forward(self, x): self.assertFalse(mEval_freezed.hasattr('fc1')) self.assertFalse(mEval_freezed.hasattr('dropout2')) self.assertFalse(mEval_freezed.hasattr('fc2')) - with self.assertRaisesRegex(RuntimeError, "does not have a field with name 'state_dict'"): + with self.assertRaisesRegex(AttributeError, "does not have a field with name 'state_dict'"): print(mEval_freezed.state_dict()) buffer = io.BytesIO() torch.jit.save(mEval_freezed, buffer) diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index ee288b65551f..df482403f6c7 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -286,3 +286,12 @@ def test_profiler_custom_op(self): if e.name == '_TorchScriptTesting::take_an_instance': found_event = True self.assertTrue(found_event) + + def test_torchbind_getattr(self): + foo = 
torch.classes._TorchScriptTesting._StackString(["test"]) + self.assertEqual(None, getattr(foo, 'bar', None)) + + def test_torchbind_attr_exception(self): + foo = torch.classes._TorchScriptTesting._StackString(["test"]) + with self.assertRaisesRegex(AttributeError, 'does not have a field'): + foo.bar diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index eb735b73d541..73042117a45c 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -159,6 +159,13 @@ ValueError::ValueError(const char *format, ...) { va_end(fmt_args); } +AttributeError::AttributeError(const char* format, ...) { + va_list fmt_args; + va_start(fmt_args, format); + msg = formatMessage(format, fmt_args); + va_end(fmt_args); +} + void PyWarningHandler::process( const c10::SourceLocation& source_location, const std::string& msg, diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 66e335b2bc76..c9d096270d2a 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -291,6 +291,14 @@ struct NotImplementedError : public PyTorchError { } }; +// Translates to Python AttributeError +struct AttributeError : public PyTorchError { + AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); + PyObject* python_type() override { + return PyExc_AttributeError; + } +}; + struct WarningMeta { WarningMeta(const c10::SourceLocation& _source_location, const std::string& _msg, diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 305c254ad1c0..b32761316f09 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -12,6 +12,13 @@ using ResolverPtr = std::shared_ptr; using ObjectPtr = c10::intrusive_ptr; +// Throw this in C++ land if `attr` fails. This will be converted to a Python +// AttributeError by the Python binding code +class ObjectAttributeError : public std::runtime_error { + public: + ObjectAttributeError(const std::string& what) : std::runtime_error(what) {} +}; + struct TORCH_API Object { Object() {} Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} @@ -59,12 +66,10 @@ struct TORCH_API Object { if (auto r = _ivalue()->type()->findConstantSlot(name)) { return _ivalue()->type()->getConstant(*r); } - TORCH_CHECK( - false, - _ivalue()->type()->repr_str(), - " does not have a field with name '", - name, - "'"); + std::stringstream err; + err << _ivalue()->type()->repr_str() << " does not have a field with name '" + << name.c_str() << "'"; + throw ObjectAttributeError(err.str()); } c10::IValue attr(const std::string& name, c10::IValue or_else) const { diff --git a/torch/csrc/jit/ir/attributes.h b/torch/csrc/jit/ir/attributes.h index 21c4a0e96b8d..6a99ecfcf80d 100644 --- a/torch/csrc/jit/ir/attributes.h +++ b/torch/csrc/jit/ir/attributes.h @@ -138,8 +138,8 @@ struct TORCH_API GraphsAttr : public AttributeValue { ValueType value_; }; -struct AttributeError : public std::exception { - AttributeError(Symbol name, bool defined) { +struct IRAttributeError : public std::exception { + IRAttributeError(Symbol name, bool defined) { std::stringstream ss; if (!defined) { ss << "required keyword attribute '" << name.toUnqualString() diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index dbd9fb5ca755..71eca77809c1 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -810,7 +810,7 @@ struct TORCH_API Node { auto it = findAttr(name, true); auto* child = dynamic_cast(it->get()); if (child == nullptr) { - throw AttributeError(name, true); + throw IRAttributeError(name, true); } return child->value(); } @@ 
-825,7 +825,7 @@ struct TORCH_API Node { return v->name == name; }); if (required && it == values_.end()) { - throw AttributeError(name, false); + throw IRAttributeError(name, false); } AT_ASSERT(!required || it != values_.end()); return it; @@ -837,7 +837,7 @@ struct TORCH_API Node { return v->name == name; }); if (required && it == values_.end()) { - throw AttributeError(name, false); + throw IRAttributeError(name, false); } AT_ASSERT(!required || it != values_.end()); return it; diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 30a1fce15b0a..b2049cb4362d 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -743,18 +743,26 @@ void initJitScriptBindings(PyObject* module) { .def( "getattr", [](Object& self, const std::string& name) { - return toPyObject(self.attr(name)); + try { + return toPyObject(self.attr(name)); + } catch (const ObjectAttributeError& err) { + throw AttributeError("%s", err.what()); + } }) .def( "__getattr__", [](Object& self, const std::string& name) -> py::object { - if (name == "__qualname__") { - return py::cast(self.type()->name()->name()); - } - if (auto method = self.find_method(name)) { - return py::cast(*method); + try { + if (name == "__qualname__") { + return py::cast(self.type()->name()->name()); + } + if (auto method = self.find_method(name)) { + return py::cast(*method); + } + return toPyObject(self.attr(name)); + } catch (const ObjectAttributeError& err) { + throw AttributeError("%s", err.what()); } - return toPyObject(self.attr(name)); }) .def( "hasattr", From bb99bea7747312c0eb74c2993290be3dac8acd30 Mon Sep 17 00:00:00 2001 From: peterjc123 Date: Wed, 7 Oct 2020 08:37:01 -0700 Subject: [PATCH 29/69] Compress NVCC flags for Windows (#45842) Summary: Fixes #{issue number} This makes the command line shorter. Also updates `randomtemp` in which the previous version has a limitation that the length of the argument cannot exceed 260. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45842 Reviewed By: albanD Differential Revision: D24137088 Pulled By: ezyang fbshipit-source-id: f0b4240735306e302eb3887f54a2b7af83c9f5dc --- .jenkins/pytorch/win-test-helpers/build_pytorch.bat | 4 ++-- cmake/Dependencies.cmake | 2 +- cmake/public/cuda.cmake | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index 0ddf3b4b462c..504d3b931bc7 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -95,7 +95,7 @@ if "%USE_CUDA%"=="1" ( copy %TMP_DIR_WIN%\bin\sccache.exe %TMP_DIR_WIN%\bin\nvcc.exe :: randomtemp is used to resolve the intermittent build error related to CUDA. - :: code: https://github.com/peterjc123/randomtemp + :: code: https://github.com/peterjc123/randomtemp-rust :: issue: https://github.com/pytorch/pytorch/issues/25393 :: :: Previously, CMake uses CUDA_NVCC_EXECUTABLE for finding nvcc and then @@ -103,7 +103,7 @@ if "%USE_CUDA%"=="1" ( :: in PATH, and then pass the arguments to it. :: Currently, randomtemp is placed before sccache (%TMP_DIR_WIN%\bin\nvcc) :: so we are actually pretending sccache instead of nvcc itself. 
- curl -kL https://github.com/peterjc123/randomtemp/releases/download/v0.3/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe + curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.2/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_EXECUTABLE=%TMP_DIR_WIN%\bin\nvcc.exe set CUDA_NVCC_EXECUTABLE=%TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_BASEDIR=%TMP_DIR_WIN%\bin diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1bbb98fb3614..baf649f9449d 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1479,7 +1479,7 @@ if(NOT INTERN_BUILD_MOBILE) add_definitions(-D_CRT_SECURE_NO_DEPRECATE=1) # skip unwanted includes from windows.h add_definitions(-DWIN32_LEAN_AND_MEAN) - list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "/wd4819" "-Xcompiler" "/wd4503" "-Xcompiler" "/wd4190" "-Xcompiler" "/wd4244" "-Xcompiler" "/wd4251" "-Xcompiler" "/wd4275" "-Xcompiler" "/wd4522") + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/wd4819,/wd4503,/wd4190,/wd4244,/wd4251,/wd4275,/wd4522") endif() if(NOT MSVC) diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index c9ac37783d1c..a418724f6256 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -474,8 +474,10 @@ foreach(diag cc_clobber_ignored integer_sign_change useless_using_declaration unsigned_compare_with_zero declared_but_not_referenced bad_friend_decl) - list(APPEND CUDA_NVCC_FLAGS -Xcudafe --diag_suppress=${diag}) + list(APPEND SUPPRESS_WARNING_FLAGS --diag_suppress=${diag}) endforeach() +string(REPLACE ";" "," SUPPRESS_WARNING_FLAGS "${SUPPRESS_WARNING_FLAGS}") +list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS}) # Set C++14 support set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror") From 5640b79bf8a5412a0209a919c05c811d5427cc12 Mon Sep 17 00:00:00 2001 From: Michael Carilli Date: Wed, 7 Oct 2020 08:51:45 -0700 Subject: [PATCH 30/69] Allow consumer ops to sync on GraphRoot's gradient (#45787) Summary: Currently, a GraphRoot instance doesn't have an associated stream. Streaming backward synchronization logic assumes the instance ran on the default stream, and tells consumer ops to sync with the default stream. If the gradient the GraphRoot instance passes to consumer backward ops was populated on a non-default stream, we have a race condition. The race condition can exist even if the user doesn't give a manually populated gradient: ```python with torch.cuda.stream(side_stream): # loss.backward() implicitly synthesizes a one-element 1.0 tensor on side_stream # GraphRoot passes it to consumers, but consumers first sync on default stream, not side_stream. loss.backward() # Internally to backward(), streaming-backward logic takes over, stuff executes on the same stream it ran on in forward, # and the side_stream context is irrelevant. GraphRoot's interaction with its first consumer(s) is the spot where # the side_stream context causes a problem. ``` This PR fixes the race condition by associating a GraphRoot instance, at construction time, with the current stream(s) on the device(s) of the grads it will pass to consumers. (i think this relies on GraphRoot executing in the main thread, before backward thread(s) fork, because the grads were populated on the main thread.) The test demonstrates the race condition. It fails reliably without the PR's GraphRoot diffs and passes with the GraphRoot diffs. 
With the GraphRoot diffs, manually populating an incoming-gradient arg for `backward` (or `torch.autograd.grad`) and the actual call to `autograd.backward` will have the same stream-semantics relationship as any other pair of ops: ```python # implicit population is safe with torch.cuda.stream(side_stream): loss.backward() # explicit population in side stream then backward in side stream is safe with torch.cuda.stream(side_stream): kickoff_grad = torch.ones_like(loss) loss.backward(gradient=kickoff_grad) # explicit population in one stream then backward kickoff in another stream # is NOT safe, even with this PR's diffs, but that unsafety is consistent with # stream-semantics relationship of any pair of ops kickoff_grad = torch.ones_like(loss) with torch.cuda.stream(side_stream): loss.backward(gradient=kickoff_grad) # Safe, as you'd expect for any pair of ops kickoff_grad = torch.ones_like(loss) side_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(side_stream): loss.backward(gradient=kickoff_grad) ``` This PR also adds the last three examples above to cuda docs and references them from autograd docstrings. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45787 Reviewed By: nairbv Differential Revision: D24138376 Pulled By: albanD fbshipit-source-id: bc4cd9390f9f0358633db530b1b09f9c1080d2a3 --- docs/source/notes/cuda.rst | 41 ++++++++++++++++++++-- test/test_cuda.py | 42 +++++++++++++++++++++++ torch/autograd/__init__.py | 14 +++++++- torch/csrc/autograd/functions/basic_ops.h | 8 ++++- torch/tensor.py | 6 ++++ 5 files changed, 106 insertions(+), 5 deletions(-) diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index a34b0d7231fb..333e7a12c172 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -104,7 +104,7 @@ To get an idea of the precision and speed, see the example code below: ab_fp32 = a @ b # takes 0.11s on GA100 error = (ab_fp32 - ab_full).abs().max() # 0.0031 relative_error = error / mean # 0.000039 - + From the above example, we can see that with TF32 enabled, the speed is ~7x faster, relative error compared to double precision is approximately 2 orders of magnitude larger. If the full FP32 precision is needed, users can disable TF32 by: @@ -189,6 +189,41 @@ necessary synchronization when data is moved around, as explained above. However, when using non-default streams, it is the user's responsibility to ensure proper synchronization. +.. _bwd-cuda-stream-semantics: + +Stream semantics of backward passes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Internally, each backward CUDA op runs on the same stream that was used for its corresponding forward op. + +When manually supplying CUDA tensor(s) as a backward pass's initial gradient(s) (e.g., +:func:`autograd.backward(..., grad_tensors=initial_grads)`, +:func:`autograd.grad(..., grad_outputs=initial_grads)`, or +:meth:`tensor.backward(..., gradient=initial_grad)`), +the acts of + +1. populating the initial gradient(s) and +2. 
invoking the backward pass + +have the same stream-semantics relationship as any pair of ops:: + + # Safe, populating initial_grad and invoking backward are in the same stream context + with torch.cuda.stream(strm): + loss.backward(gradient=torch.ones_like(loss)) + + # Unsafe, populating initial_grad and invoking backward are in different stream contexts, + # without synchronization + initial_grad = torch.ones_like(loss) + with torch.cuda.stream(strm): + loss.backward(gradient=initial_grad) + + # Safe, with synchronization + initial_grad = torch.ones_like(loss) + strm.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(strm): + initial_grad.record_stream(strm) + loss.backward(gradient=initial_grad) + .. _CUDA stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams .. _cuda-memory-management: @@ -403,7 +438,7 @@ The difference between :class:`~torch.nn.parallel.DistributedDataParallel` and uses multiprocessing where a process is created for each GPU, while :class:`~torch.nn.DataParallel` uses multithreading. By using multiprocessing, each GPU has its dedicated process, this avoids the performance overhead caused -by GIL of Python interpreter. +by GIL of Python interpreter. -If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use +If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use `torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`. diff --git a/test/test_cuda.py b/test/test_cuda.py index 498fd199066f..acb9e9ede194 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1733,6 +1733,48 @@ def test_streaming_backwards_device_transfer(self): self.assertTrue(a.grad.sum().item() == 4 * size) self.assertTrue(b.grad.sum().item() == 4 * size) + def test_streaming_backward_sync_graph_root(self): + # This function tests if bwd ops running on a side stream properly sync with the GraphRoot. + # The potential bug it targets is a race condition. The test uses multiple trials and + # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail, + # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free, + # but failure does guarantee there is a bug. + fwd_bwd_op_stream = torch.cuda.Stream() + bwd_ambient_stream = torch.cuda.Stream() + # We need these streams to be different otherwise the test is meaningless. + self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream) + + size = int(1e3) + + a = torch.full((size,), 2.0, device="cuda", requires_grad=True) + b = torch.full((size,), 3.0, device="cuda", requires_grad=True) + + # I don't think we need any manual record_streams below. + # a and b remain in scope for the entire test. + # c and grad remain in scope for each iteration, and there's a full sync between iterations. + for trial in range(5): + torch.cuda.synchronize() + a.grad = b.grad = None + with torch.cuda.stream(fwd_bwd_op_stream): + c = a * b + + with torch.cuda.stream(bwd_ambient_stream): + torch.cuda.synchronize() + # Long-running dummy kernel on bwd_ambient_stream delays filling of grad + torch.cuda._sleep(int(50 * get_cycles_per_ms())) + # Fills grad on bwd_ambient_stream + grad = torch.full((size,), float(trial + 1), device="cuda") + + # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if + # bwd ops don't sync with bwd_ambient_stream before consuming grad. 
+ torch.autograd.backward(tensors=c, grad_tensors=grad) + + # assertEquals below run on bwd_ambient_stream, so this test may also fail + # if backward() fails to sync with bwd_ambient_stream at the end. + with torch.no_grad(): + self.assertEqual(a.grad, grad * b) + self.assertEqual(b.grad, grad * a) + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") @unittest.skipIf(not IS_SANDCASTLE, "Does not work on Sandcastle") def test_cuda_init_race(self): diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index d515eb49695d..4e44536d931c 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -95,6 +95,12 @@ def backward( If you have to use this function, make sure to reset the ``.grad`` fields of your parameters to ``None`` after use to break the cycle and avoid the leak. + .. note:: + + If you run any forward ops, create ``grad_tensors``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + Arguments: tensors (sequence of Tensor): Tensors of which the derivative will be computed. @@ -153,6 +159,12 @@ def grad( leaves will still be computed, and will be accumulated into their ``.grad`` attribute. + .. note:: + + If you run any forward ops, create ``grad_outputs``, and/or call ``grad`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + Arguments: outputs (sequence of Tensor): outputs of the differentiated function. inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be @@ -184,7 +196,7 @@ def grad( grad_outputs=grad_outputs, retain_graph=retain_graph, create_graph=create_graph, - only_inputs=only_inputs, + only_inputs=only_inputs, allow_unused=allow_unused, ) diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index 48f20ec408b0..1a4615466ec2 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -68,7 +68,13 @@ struct TORCH_API UndefinedGradBackward : public Node { struct TORCH_API GraphRoot : public Node { GraphRoot(edge_list functions, variable_list inputs) : Node(std::move(functions)), - outputs(std::move(inputs)) {} + outputs(std::move(inputs)) { + // Ensures calls to stream() on a GraphRoot instance reflect current stream(s) + // on devices of root grad tensors at the time the instance is constructed. + for (const auto& t : outputs) { + add_input_metadata(t); + } + } variable_list apply(variable_list&& inputs) override { return outputs; diff --git a/torch/tensor.py b/torch/tensor.py index 9709c146c815..190685cb4570 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -192,6 +192,12 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): See :ref:`Default gradient layouts` for details on the memory layout of accumulated gradients. + .. note:: + + If you run any forward ops, create ``gradient``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + Arguments: gradient (Tensor or None): Gradient w.r.t. the tensor. 
If it is a tensor, it will be automatically converted From 1bb2d41b6802da5e4e208946df687fb229135ad5 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Wed, 7 Oct 2020 08:58:03 -0700 Subject: [PATCH 31/69] Revert D20850851: caffe2/plan_executor: wait for 1 minute after exception and then abort Test Plan: revert-hammer Differential Revision: D20850851 (https://github.com/pytorch/pytorch/commit/3fbddb92b1be1f70edced886745116b8daeebb17) Original commit changeset: 330503775d80 fbshipit-source-id: 612c6c3c4d5586bc8ad00a112cd00fc74fb44243 --- caffe2/core/plan_executor.cc | 63 ------------------------------- caffe2/core/plan_executor_test.cc | 59 +---------------------------- 2 files changed, 1 insertion(+), 121 deletions(-) diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index c7c0200e5880..3f70e96fffc8 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -17,18 +17,10 @@ C10_DEFINE_bool( "If used we will handle exceptions in executor threads. " "This avoids SIGABRT but may cause process to deadlock"); -C10_DEFINE_int( - caffe2_plan_executor_exception_timeout, - 60, - "Number of seconds to wait for concurrent threads to stop on exception" - "before terminating."); - namespace caffe2 { namespace { -// ExceptionWrapper holds an exception. If exception pointers are being used, -// it'll hold the original exception pointer otherwise just the message. class ExceptionWrapper { public: ExceptionWrapper() : hasException_(false) {} @@ -47,10 +39,6 @@ class ExceptionWrapper { #endif } - const std::string& what() const { - return exceptionMsg_; - } - operator bool() { return hasException_; } @@ -63,33 +51,6 @@ class ExceptionWrapper { std::string exceptionMsg_; }; -// ExceptionWrapperTerminate terminates the program with the specified -// exception. This preserves the exception ptr and ExceptionTracer will -// correctly grab it on exit. -class ExceptionWrapperTerminate { - public: - explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) : ew_(std::move(ew)) {} - - ~ExceptionWrapperTerminate() { - ew_.rethrowException(); - } - - private: - ExceptionWrapper ew_; -}; - -// ScopeExitGuard runs the provided function when it's destructed. -class ScopeExitGuard { - public: - explicit ScopeExitGuard(std::function&& f) : f_(std::move(f)) {} - ~ScopeExitGuard() { - f_(); - } - - private: - std::function f_; -}; - struct NetDefInfo { const NetDef* netDef; // in order to keep the "override existing nets" on the top-level workflow, @@ -499,16 +460,9 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { << " with " << step.substep().size() << " concurrent substeps"; std::atomic next_substep{0}; - std::condition_variable cv; - std::atomic done{0}; std::mutex exception_mutex; ExceptionWrapper first_exception; auto worker = [&]() { - ScopeExitGuard on_exit([&] { - done += 1; - cv.notify_all(); - }); - auto num_substeps = compiledStep->recurringSubsteps.size(); int substep_id = next_substep++ % num_substeps; if (compiledStep->gotFailure) { @@ -546,23 +500,6 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { for (size_t i = 0; i < numThreads; ++i) { threads.emplace_back(worker); } - - auto workersDone = [&] { return done == numThreads; }; - - // If we get an exception, try to wait for all threads to stop - // gracefully. 
- std::unique_lock guard(exception_mutex); - cv.wait(guard, [&] { return workersDone() || first_exception; }); - cv.wait_for( - guard, - std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout), - [&] { return workersDone(); }); - if (!workersDone() && first_exception) { - LOG(ERROR) << "failed to stop concurrent workers after exception: " - << first_exception.what(); - ExceptionWrapperTerminate(std::move(first_exception)); - } - for (auto& thread : threads) { thread.join(); } diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc index 1b0eb0e718a2..86f145d72a09 100644 --- a/caffe2/core/plan_executor_test.cc +++ b/caffe2/core/plan_executor_test.cc @@ -67,29 +67,6 @@ class ErrorOp final : public Operator { REGISTER_CPU_OPERATOR(Error, ErrorOp); OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); -static std::atomic blockingErrorRuns{0}; -class BlockingErrorOp final : public Operator { - public: - BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // First n op executions should block and then start throwing errors. - if (blockingErrorRuns.fetch_sub(1) >= 1) { - LOG(INFO) << "blocking"; - while (true) { - std::this_thread::sleep_for(std::chrono::hours(10)); - } - } else { - LOG(INFO) << "throwing"; - throw TestError(); - } - } -}; - -REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); -OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); - PlanDef parallelErrorPlan() { PlanDef plan_def; @@ -124,12 +101,10 @@ PlanDef parallelErrorPlan() { } struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard(int timeout = 60) { + HandleExecutorThreadExceptionsGuard() { globalInit({ "caffe2", "--caffe2_handle_executor_threads_exceptions=1", - "--caffe2_plan_executor_exception_timeout=" + - caffe2::to_string(timeout), }); } @@ -164,38 +139,6 @@ TEST(PlanExecutorTest, ErrorAsyncPlan) { ASSERT_EQ(cancelCount, 1); } -TEST(PlanExecutorTest, BlockingErrorPlan) { - ASSERT_DEATH( - [] { - HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); - - PlanDef plan_def; - - std::string plan_def_template = R"DOC( - network { - name: "net" - op { - type: "BlockingError" - } - } - execution_step { - num_concurrent_instances: 2 - substep { - network: "net" - } - } - )DOC"; - - CAFFE_ENFORCE( - TextFormat::ParseFromString(plan_def_template, &plan_def)); - Workspace ws; - blockingErrorRuns = 1; - ws.RunPlan(plan_def); - FAIL() << "shouldn't have reached this point"; - }(), - "failed to stop concurrent workers after exception: test error"); -} - } // namespace caffe2 #endif From 5ce31b6f3f56e850ea3a5d4510487702303a727d Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 7 Oct 2020 09:18:03 -0700 Subject: [PATCH 32/69] [ONNX] Improve error handling for adaptive_pool (#45874) Summary: Duplicate of https://github.com/pytorch/pytorch/issues/43032 This update would also improve error handling for interpolate with 'area' mode. 
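A minimal sketch of the kind of export the clearer message targets (the module choice and file name below are illustrative placeholders, not taken from the patch): adaptive pooling to an output size that does not evenly divide the input size has no direct ONNX equivalent, so the export is expected to fail with a descriptive `RuntimeError` rather than an opaque one.

```python
import torch
import torch.nn as nn

# Output size 5 is not a factor of input length 6, so this adaptive pool cannot
# be lowered to a plain ONNX AveragePool; export should raise a RuntimeError
# whose message explains the size mismatch.
model = nn.AdaptiveAvgPool1d(output_size=5)
x = torch.randn(1, 2, 6)

try:
    torch.onnx.export(model, x, "adaptive_pool.onnx")
except RuntimeError as err:
    print(err)
```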
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45874 Reviewed By: albanD Differential Revision: D24141266 Pulled By: bzinodev fbshipit-source-id: 7559f1d6af4f1ef3507c15a1aee76fe01fa433cd --- test/onnx/test_pytorch_onnx_onnxruntime.py | 9 ++++++++- torch/onnx/symbolic_opset9.py | 10 ++++++++-- torch/onnx/utils.py | 3 +-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 23d4879a8a4c..77577f687de0 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -1667,7 +1667,14 @@ def forward(self, x, y): y = torch.randn(16, 16, requires_grad=True) self.run_test(MyModel(), (x, y)) - @disableScriptTest() + def test_interpolate_adaptive_pooling_error(self): + x = torch.randn(1, 2, 6, requires_grad=True) + with self.assertRaises(RuntimeError) as cm: + self._interpolate(x, "area", True, True) + + with self.assertRaises(RuntimeError) as cm: + self._interpolate(x, "area", False, True) + def test_groupnorm(self): model = torch.nn.GroupNorm(3, 6, 0.002) x = torch.randn(4, 6, 180, 180, 180) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 7e8b04bf1612..eed84e437b2c 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -826,7 +826,6 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include def _adaptive_pool(name, type, tuple_fn, fn=None): - @parse_args('v', 'is') def symbolic_fn(g, input, output_size): # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, # by executing a GlobalPool. @@ -837,6 +836,10 @@ def symbolic_fn(g, input, output_size): # so we try using max_poolxd_with_indices, and if it is not possible # (input is not a complete tensor or output size not factor of input size) # then we call GlobalAveragePool and return None for the indices + try: + output_size = _parse_arg(output_size, 'is') + except Exception: + return sym_help._onnx_unsupported('adaptive pooling, since output_size is not constant.') if output_size == [1] * len(output_size) and type == "AveragePool": return g.op("GlobalAveragePool", input) if not input.isCompleteTensor(): @@ -849,7 +852,10 @@ def symbolic_fn(g, input, output_size): if mod != [0] * len(mod): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None - return _unimplemented(name, 'output size that are not factor of input size') + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return _unimplemented(name, 'output size that are not factor of input size') + else: + return sym_help._onnx_unsupported(name + ', since output size is not factor of input size') k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] # call max_poolxd_with_indices to get indices in the output if type == "MaxPool": diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 16e7af721ebf..a26a4cdf2332 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1003,8 +1003,7 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor else: raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. " "If you are trying to export a custom operator, make sure you registered " - "it with the right domain and version. 
" - "Otherwise, please report a bug.".format(ns, op_name)) + "it with the right domain and version.".format(ns, op_name)) except RuntimeError: if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH: return None From 5a2773702f131deba03313d1cc15f62347d2b68e Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 7 Oct 2020 09:29:39 -0700 Subject: [PATCH 33/69] add test sharding to CUDA on linux (#45972) Summary: splits up all the cuda linux tests into 2 shards to decrease total test runtime Pull Request resolved: https://github.com/pytorch/pytorch/pull/45972 Reviewed By: malfet Differential Revision: D24163521 Pulled By: janeyx99 fbshipit-source-id: da6e88eb4305192fb287c4458c31199bf62354c0 --- .../cimodel/data/pytorch_build_definitions.py | 1 + .circleci/config.yml | 78 ++++++++++++++++--- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index ccd97a053516..3afe37b29f2d 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -288,6 +288,7 @@ def instantiate_configs(): rocm_version = None if compiler_name == "cuda": cuda_version = fc.find_prop("compiler_version") + restrict_phases = ["build", "test1", "test2"] elif compiler_name == "rocm": rocm_version = fc.find_prop("compiler_version") diff --git a/.circleci/config.yml b/.circleci/config.yml index 208e0d09eed0..1ee03badb92a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -6668,7 +6668,7 @@ workflows: build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test + name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test1 requires: - pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build filters: @@ -6677,7 +6677,21 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test" + build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test2 + requires: + - pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build + filters: + branches: + only: + - master + - /ci-all\/.*/ + - /release\/.*/ + build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6706,10 +6720,18 @@ workflows: build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test + name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 requires: - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build - build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test" + build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" + 
use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 + requires: + - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6780,7 +6802,21 @@ workflows: build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test + name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 + requires: + - pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build + filters: + branches: + only: + - master + - /ci-all\/.*/ + - /release\/.*/ + build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 requires: - pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build filters: @@ -6789,7 +6825,7 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test" + build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6806,7 +6842,21 @@ workflows: build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test + name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1 + requires: + - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build + filters: + branches: + only: + - master + - /ci-all\/.*/ + - /release\/.*/ + build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2 requires: - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build filters: @@ -6815,7 +6865,7 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" + build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6826,10 +6876,18 @@ workflows: build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test + name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1 + 
requires: + - pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build + build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2 requires: - pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build - build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" + build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium From b186831c08e0e4e447eedb8a5cfab582995d37f9 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Wed, 7 Oct 2020 09:53:52 -0700 Subject: [PATCH 34/69] Automatic update of fbcode/foxi to 6a4e19a2aaf7ae4b9fa9597526e65b395d5e79ad (#45951) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45951 Pull Request resolved: https://github.com/pytorch/glow/pull/4966 Previous import was 4aba696ec8f31794fd42880346dc586486205e0a Included changes: - **[6a4e19a](https://github.com/houseroad/foxi/commit/6a4e19a)**: Add fatal error value (#20) Test Plan: build Reviewed By: houseroad Differential Revision: D24156364 fbshipit-source-id: f833ada8d6586865e1831e2c8c632e3844c7b6a1 --- third_party/foxi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/foxi b/third_party/foxi index 4aba696ec8f3..6a4e19a2aaf7 160000 --- a/third_party/foxi +++ b/third_party/foxi @@ -1 +1 @@ -Subproject commit 4aba696ec8f31794fd42880346dc586486205e0a +Subproject commit 6a4e19a2aaf7ae4b9fa9597526e65b395d5e79ad From 30bf799f9c2ed4bf9e837eecfc0f540edb0dc7f5 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Wed, 7 Oct 2020 10:21:12 -0700 Subject: [PATCH 35/69] `torch.matrix_exp` doc fix (#45909) Summary: As per title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45909 Reviewed By: dzhulgakov Differential Revision: D24147314 Pulled By: albanD fbshipit-source-id: fc21094f4dbdd04cc2063a9639b9d1f5728cb53f --- torch/_torch_docs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 28f9ebf1a585..4ad620d4abd7 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4292,16 +4292,15 @@ def merge_dicts(*dicts): add_docstr(torch.matrix_exp, r""" -matrix_power(input) -> Tensor - Returns the matrix exponential. Supports batched input. For a matrix ``A``, the matrix exponential is defined as .. math:: - \exp^A = \sum_{k=0}^\infty A^k / k!. + \mathrm{e}^A = \sum_{k=0}^\infty A^k / k! """ + r""" The implementation is based on: + Bader, P.; Blanes, S.; Casas, F. Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation. Mathematics 2019, 7, 1174. 
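The corrected series definition above can be sanity-checked numerically against `torch.matrix_exp`; a minimal sketch (assuming double precision and a small random matrix, so that a truncated Taylor sum converges):

```python
import torch

A = torch.randn(3, 3, dtype=torch.float64)

# Truncated Taylor series: sum over k of A^k / k!, accumulated term by term.
acc = torch.zeros_like(A)
term = torch.eye(3, dtype=torch.float64)  # the k = 0 term
for k in range(1, 30):
    acc = acc + term
    term = term @ A / k

# The truncated sum should agree closely with the closed-form routine.
print(torch.allclose(acc, torch.matrix_exp(A), atol=1e-10))  # expected: True
```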
From 83d2c9a23250ff12fb894d070aa0b420427244b8 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 7 Oct 2020 10:24:57 -0700
Subject: [PATCH 36/69] [quant] Add quantized Sigmoid module (#45883)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45883

Test Plan:
python test/test_quantization.py TestStaticQuantizedModule.test_sigmoid

Imported from OSS

Reviewed By: z-a-f

Differential Revision: D24129116

fbshipit-source-id: aa960549509c60374012f35b1f5be39e90418099
---
 test/quantization/test_quantized_module.py |  3 +++
 torch/nn/quantized/modules/__init__.py     |  3 ++-
 torch/nn/quantized/modules/activation.py   | 21 +++++++++++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py
index f5c3a8e3e8d5..a1fbc308dfde 100644
--- a/test/quantization/test_quantized_module.py
+++ b/test/quantization/test_quantized_module.py
@@ -716,6 +716,9 @@ def test_elu(self):
     def test_leaky_relu(self):
         self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})

+    def test_sigmoid(self):
+        self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})
+
     @given(
         num_embeddings=st.integers(10, 50),
         embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py
index fe6a5f6c3765..72595eb3cea4 100644
--- a/torch/nn/quantized/modules/__init__.py
+++ b/torch/nn/quantized/modules/__init__.py
@@ -2,7 +2,7 @@
 import torch
 from torch.nn.modules.pooling import MaxPool2d

-from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU
+from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid
 from .batchnorm import BatchNorm2d, BatchNorm3d
 from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
     InstanceNorm2d, InstanceNorm3d
@@ -100,6 +100,7 @@ def from_float(mod):
     'Hardswish',
     'ELU',
     'LeakyReLU',
+    'Sigmoid',
     'LayerNorm',
     'GroupNorm',
     'InstanceNorm1d',
diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py
index f2017c85f0fd..366e1e63a039 100644
--- a/torch/nn/quantized/modules/activation.py
+++ b/torch/nn/quantized/modules/activation.py
@@ -149,3 +149,24 @@ def _get_name(self):
     def from_float(cls, mod):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
+
+class Sigmoid(torch.nn.Sigmoid):
+    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
+ + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point) + + @classmethod + def from_float(cls, mod): + output_scale, output_zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) From 9679e1affcded109201a4b33900943cf452c2ae0 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 7 Oct 2020 10:50:50 -0700 Subject: [PATCH 37/69] annotate torch.autograd.* modules (#45004) Summary: Fixes https://github.com/pytorch/pytorch/issues/44638 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45004 Reviewed By: VitalyFedyunin Differential Revision: D24113562 Pulled By: ezyang fbshipit-source-id: a85018b7e08b2fe6cf2bc14a217eb418cb2b9de4 --- mypy.ini | 21 ------ torch/_C/__init__.pyi.in | 10 +++ torch/_C/_autograd.pyi | 8 ++- torch/_C/_functions.pyi | 12 ++++ torch/autograd/function.py | 22 ++++--- torch/autograd/functional.py | 12 ++-- torch/autograd/gradcheck.py | 47 +++++++------- torch/autograd/profiler.py | 121 ++++++++++++++++++++--------------- torch/autograd/variable.py | 5 +- 9 files changed, 143 insertions(+), 115 deletions(-) create mode 100644 torch/_C/_functions.pyi diff --git a/mypy.ini b/mypy.ini index ea7bdb1a83ed..af39fd619732 100644 --- a/mypy.ini +++ b/mypy.ini @@ -180,27 +180,6 @@ ignore_errors = True [mypy-torch.utils.hipify.hipify_python] ignore_errors = True -[mypy-torch.autograd._functions.tensor] -ignore_errors = True - -[mypy-torch.autograd.function] -ignore_errors = True - -[mypy-torch.autograd.functional] -ignore_errors = True - -[mypy-torch.autograd.profiler] -ignore_errors = True - -[mypy-torch.autograd.gradcheck] -ignore_errors = True - -[mypy-torch.autograd.anomaly_mode] -ignore_errors = True - -[mypy-torch.autograd.variable] -ignore_errors = True - [mypy-torch.nn.quantized.modules.batchnorm] ignore_errors = True diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f1e96e31d994..2ad2f647c2af 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -93,6 +93,7 @@ def DisableTorchFunction(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided : layout = ... sparse_coo : layout = ... +_mkldnn : layout = ... # Defined in torch/csrc/MemoryFormat.cpp class memory_format: ... @@ -268,6 +269,10 @@ def import_ir_module_from_buffer( class Graph: ... +# Defined in torch/csrc/jit/ir/ir.h +class Value: + ... + # Defined in torch/aten/src/ATen/core/function_schema.h class FunctionSchema: ... @@ -389,6 +394,7 @@ def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython +def _demangle(str) -> str: ... # c10::demangle # Defined in `valgrind.h` and `callgrind.h` respecitively. def valgrind_supported_platform() -> _bool: ... # NVALGRIND @@ -497,6 +503,10 @@ ${legacy_storage_base_hints} # TODO: where ${legacy_class_hints} +# Defined in torch/csrc/autograd/python_engine.cpp +class _ImperativeEngine: + ... 
+ # Defined in torch/csrc/autograd/python_variable.cpp class _TensorBase(object): requires_grad: _bool diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 653b705fe135..a154fb1948c1 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -11,10 +11,14 @@ class ProfilerState(Enum): class ProfilerConfig: - def __init__(self, state: ProfilerState, report_input_shapes: bool, profile_memory: bool) -> None: ... + def __init__( + self, state: ProfilerState, + report_input_shapes: bool, + profile_memory: bool, + with_stack: bool + ) -> None: ... ... - class ProfilerEvent: def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ... def cpu_memory_usage(self) -> int: ... diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi new file mode 100644 index 000000000000..4ad76e4f86e5 --- /dev/null +++ b/torch/_C/_functions.pyi @@ -0,0 +1,12 @@ +from torch import Tensor +from typing import AnyStr, List + +class UndefinedGrad: + def __init__(self) -> None: ... + def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ... + ... + +class DelayedError: + def __init__(self, msg: AnyStr, num_inputs: int) -> None: ... + def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ... + ... \ No newline at end of file diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 6714444acdcf..0d546ceb28d6 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,11 +1,12 @@ import torch import torch._C as _C +from torch._C import _functions import torch.utils.hooks as hooks from torch._six import with_metaclass import functools import warnings from collections import OrderedDict -from typing import Any +from typing import Any, List, Optional class _ContextMethodMixin(object): @@ -84,7 +85,8 @@ class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): _is_legacy = False def apply(self, *args): - return self._forward_cls.backward(self, *args) + # _forward_cls is defined by derived class + return self._forward_cls.backward(self, *args) # type: ignore class FunctionMeta(type): @@ -115,8 +117,8 @@ def __init__(cls, name, bases, attrs): return super(FunctionMeta, cls).__init__(name, bases, attrs) - -class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): +# mypy doesn't understand `with_metaclass` from torch._six +class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # type: ignore r"""Records operation history and defines formulas for differentiating ops. 
See the Note on extending the autograd engine for more details on how to use @@ -227,7 +229,7 @@ def wrapper(ctx, *args): if not isinstance(outputs, tuple): outputs = (outputs,) - err_fn = torch._C._functions.DelayedError( + err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked" b"with @once_differentiable", len(outputs)) @@ -330,7 +332,7 @@ def _unflatten(input, proto): # unflatten a list or tuple input into a nested list/tuple structure # specified by proto def unflatten_helper(input, proto): - res = [] + res: List[Optional[torch.Tensor]] = [] if hasattr(proto, "_jit_wrap"): return proto._jit_wrap(input) if not isinstance(proto, (list, tuple)): @@ -379,16 +381,16 @@ def _do_backward(self, gradients, retain_variables): del self._to_save_nested return result - def backward(self, *gradients: Any) -> Any: + def backward(self, *gradients: Any) -> Any: # type: ignore nested_gradients = _unflatten(gradients, self._nested_output) - result = self.backward_extended(*nested_gradients) + result = self.backward_extended(*nested_gradients) # type: ignore return tuple(_iter_None_tensors(result)) __call__ = _do_forward - def forward(self, *args: Any) -> Any: + def forward(self, *args: Any) -> Any: # type: ignore nested_tensors = _map_tensor_data(self._nested_input) - result = self.forward_extended(*nested_tensors) + result = self.forward_extended(*nested_tensors) # type: ignore del self._nested_input self._nested_output = result return tuple(_iter_tensors(result)) diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 58e780c87d1b..70961cef9744 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -1,4 +1,5 @@ import torch +from typing import Tuple, List # Utility functions @@ -131,8 +132,8 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retai assert isinstance(grad_outputs, tuple) assert len(outputs) == len(grad_outputs) - new_outputs = tuple() - new_grad_outputs = tuple() + new_outputs: Tuple[torch.Tensor, ...] = tuple() + new_grad_outputs: Tuple[torch.Tensor, ...] = tuple() for out, grad_out in zip(outputs, grad_outputs): if out is not None and out.requires_grad: new_outputs += (out,) @@ -153,7 +154,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage): if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage)) - res = tuple() + res: Tuple[torch.Tensor, ...] = tuple() for i, grads_i in enumerate(grads): if grads_i is None: if strict: @@ -427,10 +428,11 @@ def jacobian(func, inputs, create_graph=False, strict=False): "jacobian") _check_requires_grad(outputs, "outputs", strict=strict) - jacobian = tuple() + jacobian: Tuple[torch.Tensor, ...] 
= tuple() for i, out in enumerate(outputs): - jac_i = tuple([] for _ in range(len(inputs))) + # mypy complains that expression and variable have different types due to the empty list + jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore for j in range(out.nelement()): vj = _autograd_grad((out.reshape(-1)[j],), inputs, retain_graph=True, create_graph=create_graph) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index b2bea4570c2a..531bcc6f27d8 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -5,7 +5,7 @@ from torch.overrides import is_tensor_like from itertools import product import warnings -from typing import Callable, Union, Optional +from typing import Callable, Union, Optional, Iterable, List def zero_gradients(x): if isinstance(x, torch.Tensor): @@ -29,15 +29,16 @@ def make_jacobian(input, num_out): lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input))) if not jacobians: return None - return type(input)(jacobians) + return type(input)(jacobians) # type: ignore else: return None -def iter_tensors(x, only_requiring_grad=False): +def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]: if is_tensor_like(x): - if x.requires_grad or not only_requiring_grad: - yield x + # mypy doesn't narrow type of `x` to torch.Tensor + if x.requires_grad or not only_requiring_grad: # type: ignore + yield x # type: ignore elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str): for elem in x: for result in iter_tensors(elem, only_requiring_grad): @@ -137,7 +138,7 @@ def get_stride(size): indices = x_indices[i].tolist() + list(x_idx) d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size))) update_jacobians(x_value, x_idx, d_tensor, d_idx) - elif x_tensor.layout == torch._mkldnn: + elif x_tensor.layout == torch._mkldnn: # type: ignore # Use .data here to get around the version check x_tensor = x_tensor.data if len(input) != 1: @@ -163,7 +164,7 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0): if output.is_sparse: raise ValueError('Sparse output is not supported at gradcheck yet. ' 'Please call to_dense() on the output of fn for gradcheck.') - if output.layout == torch._mkldnn: + if output.layout == torch._mkldnn: # type: ignore raise ValueError('MKLDNN output is not supported at gradcheck yet. ' 'Please call to_dense() on the output of fn for gradcheck.') diff_input_list = list(iter_tensors(input, True)) @@ -303,13 +304,13 @@ def fail_test(msg): content = inp._values() if inp.is_sparse else inp # TODO: To cover more problematic cases, replace stride = 0 check with # "any overlap in memory" once we have a proper function to check it. - if content.layout is not torch._mkldnn and \ - not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())): - raise RuntimeError( - 'The {}th input has a dimension with stride 0. gradcheck only ' - 'supports inputs that are non-overlapping to be able to ' - 'compute the numerical gradients correctly. You should call ' - '.contiguous on the input before passing it to gradcheck.') + if content.layout is not torch._mkldnn: # type: ignore + if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())): + raise RuntimeError( + 'The {}th input has a dimension with stride 0. gradcheck only ' + 'supports inputs that are non-overlapping to be able to ' + 'compute the numerical gradients correctly. 
You should call ' + '.contiguous on the input before passing it to gradcheck.') any_input_requiring_grad = True inp.retain_grad() if not any_input_requiring_grad: @@ -403,30 +404,30 @@ def not_reentrant_error(error_str=''): # check if the backward multiplies by grad_output output = _differentiable_outputs(func(*tupled_inputs)) if any([o.requires_grad for o in output]): - diff_input_list = list(iter_tensors(tupled_inputs, True)) + diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True)) if not diff_input_list: raise RuntimeError("no Tensors requiring grad found in input") grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output], allow_unused=True) - for gi, i in zip(grads_input, diff_input_list): + for gi, di in zip(grads_input, diff_input_list): if gi is None: continue if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: - if gi.layout != i.layout: - return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(i.layout) + ')') + if gi.layout != di.layout: + return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')') if gi.layout == torch.sparse_coo: - if gi.sparse_dim() != i.sparse_dim(): + if gi.sparse_dim() != di.sparse_dim(): return fail_test('grad is sparse tensor, but has incorrect sparse_dim') - if gi.dense_dim() != i.dense_dim(): + if gi.dense_dim() != di.dense_dim(): return fail_test('grad is sparse tensor, but has incorrect dense_dim') gi = gi.to_dense() - i = i.to_dense() + di = di.to_dense() if not gi.eq(0).all(): return fail_test('backward not multiplied by grad_output') - if gi.dtype != i.dtype or gi.device != i.device or gi.is_sparse != i.is_sparse: + if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse: return fail_test("grad is incorrect type") - if gi.size() != i.size(): + if gi.size() != di.size(): return fail_test('grad is incorrect size') if check_undefined_grad: diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 8d33be090b27..eba7368cb03e 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -6,6 +6,8 @@ from collections import defaultdict, namedtuple from operator import attrgetter +from typing import List, Dict, Tuple, Optional + try: # Available in Python >= 3.2 from contextlib import ContextDecorator @@ -13,6 +15,13 @@ import functools class ContextDecorator(object): # type: ignore[no-redef] + + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + def __call__(self, func): @functools.wraps(func) def wrapped(*args, **kwargs): @@ -78,13 +87,13 @@ def populate_cpu_children(self): # Algorithm has O(N * log(N)) complexity where N is number of # intervals for thread_id, thread_events in threads: - thread_events = sorted( + thread_events_ = sorted( thread_events, key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end], ) - current_events = [] + current_events: List[FunctionEvent] = [] cur_end = 0 - for event in thread_events: + for event in thread_events_: while len(current_events) > 0: parent = current_events[-1] if event.cpu_interval.start >= parent.cpu_interval.end or \ @@ -253,7 +262,7 @@ def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0): An EventList containing FunctionEventAvg objects. 
""" self.populate_cpu_children() - stats = defaultdict(FunctionEventAvg) + stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg) def get_key(event, group_by_input_shapes, group_by_stack_n): key = [str(event.key), str(event.node_id)] @@ -413,6 +422,7 @@ def _check_finish(self): def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only=False): self._check_finish() + assert self.function_events is not None return self.function_events.table( sort_by=sort_by, row_limit=row_limit, header=header, top_level_events_only=top_level_events_only @@ -421,16 +431,19 @@ def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only= def export_chrome_trace(self, path): self._check_finish() + assert self.function_events is not None return self.function_events.export_chrome_trace(path) export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): self._check_finish() + assert self.function_events is not None return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) key_averages.__doc__ = EventList.key_averages.__doc__ def total_average(self): self._check_finish() + assert self.function_events is not None return self.function_events.total_average() total_average.__doc__ = EventList.total_average.__doc__ @@ -440,6 +453,7 @@ def self_cpu_time_total(self): all self times across all the events. """ self._check_finish() + assert self.function_events is not None return self.function_events.self_cpu_time_total @@ -694,11 +708,11 @@ class FormattedTimesMixin(object): @property def cpu_time(self): - return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count + return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore @property def cuda_time(self): - return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count + return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore class Interval(object): @@ -719,24 +733,24 @@ def __init__( self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None, stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, is_remote=True, sequence_nr=-1): - self.id = id - self.node_id = node_id - self.name = name - self.cpu_interval = Interval(cpu_start, cpu_end) - self.thread = thread - self.fwd_thread = fwd_thread - self.kernels = [] - self.count = 1 - self.cpu_children = [] - self.cpu_parent = None - self.input_shapes = input_shapes - self.stack = stack - self.scope = scope - self.cpu_memory_usage = cpu_memory_usage - self.cuda_memory_usage = cuda_memory_usage - self.is_async = is_async - self.is_remote = is_remote - self.sequence_nr = sequence_nr + self.id: int = id + self.node_id: int = node_id + self.name: str = name + self.cpu_interval: Interval = Interval(cpu_start, cpu_end) + self.thread: int = thread + self.fwd_thread: Optional[int] = fwd_thread + self.kernels: List[Kernel] = [] + self.count: int = 1 + self.cpu_children: List[FunctionEvent] = [] + self.cpu_parent: Optional[FunctionEvent] = None + self.input_shapes: Tuple[int, ...] 
= input_shapes + self.stack: List = stack + self.scope: int = scope + self.cpu_memory_usage: int = cpu_memory_usage + self.cuda_memory_usage: int = cuda_memory_usage + self.is_async: bool = is_async + self.is_remote: bool = is_remote + self.sequence_nr: int = sequence_nr def append_kernel(self, name, device, start, end): self.kernels.append(Kernel(name, device, Interval(start, end))) @@ -830,24 +844,24 @@ def __repr__(self): class FunctionEventAvg(FormattedTimesMixin): """Used to average stats over multiple FunctionEvent objects.""" def __init__(self): - self.key = None - self.count = 0 - self.node_id = 0 - self.is_async = False - self.is_remote = False - self.cpu_time_total = 0 - self.cuda_time_total = 0 - self.self_cpu_time_total = 0 - self.self_cuda_time_total = 0 - self.input_shapes = None - self.stack = None - self.scope = None - self.cpu_memory_usage = 0 - self.cuda_memory_usage = 0 - self.self_cpu_memory_usage = 0 - self.self_cuda_memory_usage = 0 - self.cpu_children = None - self.cpu_parent = None + self.key: Optional[str] = None + self.count: int = 0 + self.node_id: int = 0 + self.is_async: bool = False + self.is_remote: bool = False + self.cpu_time_total: int = 0 + self.cuda_time_total: int = 0 + self.self_cpu_time_total: int = 0 + self.self_cuda_time_total: int = 0 + self.input_shapes: Optional[List[List[int]]] = None + self.stack: Optional[List] = None + self.scope: Optional[int] = None + self.cpu_memory_usage: int = 0 + self.cuda_memory_usage: int = 0 + self.self_cpu_memory_usage: int = 0 + self.self_cuda_memory_usage: int = 0 + self.cpu_children: Optional[List[FunctionEvent]] = None + self.cpu_parent: Optional[FunctionEvent] = None def add(self, other): if self.key is None: @@ -950,6 +964,7 @@ def filter_stack_entry(entry): # and the CPU time of the cuda start event for the device def adjusted_time(cuda_record, cuda_records_map): assert cuda_record.device() != -1 + assert start_record is not None cuda_time_0 = cuda_records_map[(cuda_record.node_id(), cuda_record.device())] return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0) @@ -1102,6 +1117,8 @@ def parse_nvprof_trace(path): for row in conn.execute(marker_query): unique.see(row['marker_id']) evt = FunctionEvent(id=row['marker_id'], + node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure + # that pytorch doesn't crash when creating a FunctionEvent() object name=strings[row['name']], cpu_start=row['start_time'], cpu_end=row['end_time'], @@ -1215,15 +1232,15 @@ def build_table( # Have to use a list because nonlocal is Py3 only... 
SPACING_SIZE = 2 - row_format = [""] - header_sep = [""] - line_length = [-SPACING_SIZE] + row_format_lst = [""] + header_sep_lst = [""] + line_length_lst = [-SPACING_SIZE] MAX_STACK_ENTRY = 5 def add_column(padding, text_dir='>'): - row_format[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE) - header_sep[0] += '-' * padding + (' ' * SPACING_SIZE) - line_length[0] += padding + SPACING_SIZE + row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE) + header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE) + line_length_lst[0] += padding + SPACING_SIZE add_column(name_column_width) for _ in headers[1:]: @@ -1237,10 +1254,10 @@ def add_column(padding, text_dir='>'): headers.append('Source Location') add_column(src_column_width, text_dir='<') - row_format = row_format[0] - header_sep = header_sep[0] - line_length = line_length[0] - add_column = None + row_format = row_format_lst[0] + header_sep = header_sep_lst[0] + line_length = line_length_lst[0] + add_column = None # type: ignore # Have to use a list because nonlocal is Py3 only... result = [] diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index 1008d741a6cf..307f82db34db 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -7,9 +7,10 @@ def __instancecheck__(cls, other): return isinstance(other, torch.Tensor) -class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): +# mypy doesn't understand torch._six.with_metaclass +class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore pass from torch._C import _ImperativeEngine as ImperativeEngine -Variable._execution_engine = ImperativeEngine() +Variable._execution_engine = ImperativeEngine() # type: ignore From 8fb32b9f5506970585c863923a2185ebe45e3984 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Wed, 7 Oct 2020 11:46:56 -0700 Subject: [PATCH 38/69] Parametrize # of longest tests in print_test_stats (#45941) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45941 This adds CLI options to the `test/print_test_stats.py` script for specifying how many of the longest tests should be printed. It also makes the following incidental changes: - The script now has a `--help` option to describe its usage. - The number of longest tests being shown is now displayed as a number, rather than in words. - The median time is now displayed with the label `median_time` instead of `mean_time`, is calculated using `statistics.median` instead of raw indexing and bit shifting, and is displayed even when there are only two tests in a class. 
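As a side note on the median change: `statistics.median` and the old `sorted_tests[test_count>>1].time` lookup only agree when the number of tests is odd. For an even count the old expression picked the upper-middle element, while `statistics.median` averages the two middle values, which also gives a sensible answer for a class with exactly two tests. A minimal sketch with made-up timings (the values below are illustrative only, not from a real test run):

```
import statistics

# Hypothetical per-test run times in seconds.
times = sorted([0.50, 0.90, 1.20, 3.40])

old_median = times[len(times) >> 1]    # upper-middle element -> 1.20
new_median = statistics.median(times)  # average of the two middle values -> 1.05

print(f"old: {old_median:.2f}s  new: {new_median:.2f}s")
```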
Test Plan: Imported from OSS Reviewed By: walterddr, seemethere Differential Revision: D24154491 Pulled By: samestep fbshipit-source-id: 9fa402bf0fa56badd505f87f289ac9cca1862d6b --- test/print_test_stats.py | 61 ++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/test/print_test_stats.py b/test/print_test_stats.py index 339f6800f61b..d1ccc3d36cc1 100755 --- a/test/print_test_stats.py +++ b/test/print_test_stats.py @@ -5,6 +5,7 @@ from glob import glob import json import os +import statistics import time import datetime @@ -42,21 +43,19 @@ def append(self, test_case): self.skipped_count += 1 if test_case.skipped else 0 self.errored_count += 1 if test_case.errored else 0 - def print_report(self): + def print_report(self, num_longest=3): sorted_tests = sorted(self.test_cases, key=lambda x: x.time) test_count = len(sorted_tests) print(f"class {self.name}:") print(f" tests: {test_count} failed: {self.failed_count} skipped: {self.skipped_count} errored: {self.errored_count}") print(f" run_time: {self.total_time:.2f} seconds") print(f" avg_time: {self.total_time/test_count:.2f} seconds") - if test_count > 2: - print(f" mean_time: {sorted_tests[test_count>>1].time:.2f} seconds") - print(" Three longest tests:") - for idx in [-1, -2, -3]: - print(f" {sorted_tests[idx].name} time: {sorted_tests[idx].time:.2f} seconds") - elif test_count > 0: - print(" Longest test:") - print(f" {sorted_tests[-1].name} time: {sorted_tests[-1].time:.2f} seconds") + if test_count >= 2: + print(f" median_time: {statistics.median(x.time for x in sorted_tests):.2f} seconds") + sorted_tests = sorted_tests[-num_longest:] + print(f" {len(sorted_tests)} longest tests:") + for test in reversed(sorted_tests): + print(f" {test.name} time: {test.time:.2f} seconds") print("") @@ -126,15 +125,42 @@ def send_report(reports): print("Scribe report status: {}".format(r.text)) r.raise_for_status() +def positive_integer(value): + parsed = int(value) + if parsed < 1: + raise argparse.ArgumentTypeError(f"{value} is not a natural number") + return parsed + if __name__ == '__main__': + import argparse import sys - if len(sys.argv) == 1: - print("Please specify test report folder") - sys.exit(0) + parser = argparse.ArgumentParser( + "Print statistics from test XML output.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--longest-of-class", + type=positive_integer, + default=3, + metavar="N", + help="how many longest tests to show for each class", + ) + parser.add_argument( + "--longest-of-run", + type=positive_integer, + default=10, + metavar="N", + help="how many longest tests to show from the entire run", + ) + parser.add_argument( + "folder", + help="test report folder", + ) + args = parser.parse_args() - reports = parse_reports(sys.argv[1]) + reports = parse_reports(args.folder) if len(reports) == 0: - print(f"No test reports found in {sys.argv[1]}") + print(f"No test reports found in {args.folder}") sys.exit(0) send_report(reports) @@ -143,13 +169,12 @@ def send_report(reports): total_time = 0 for name in sorted(reports.keys()): test_suite = reports[name] - test_suite.print_report() + test_suite.print_report(args.longest_of_class) total_time += test_suite.total_time longest_tests.extend(test_suite.test_cases) - if len(longest_tests) > 10: - longest_tests = sorted(longest_tests, key=lambda x: x.time)[-10:] + longest_tests = sorted(longest_tests, key=lambda x: x.time)[-args.longest_of_run:] print(f"Total runtime is 
{datetime.timedelta(seconds=int(total_time))}") - print("Ten longest tests of entire run:") + print(f"{len(longest_tests)} longest tests of entire run:") for test_case in reversed(longest_tests): print(f" {test_case.class_name}.{test_case.name} time: {test_case.time:.2f} seconds") From c8d76ff7dc21d426c7f2851803071820865a828f Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Wed, 7 Oct 2020 12:16:52 -0700 Subject: [PATCH 39/69] Improve logging in ProcessGroupNCCL for debugging purposes. (#45780) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45780 When training jobs running with NCCL fail sometimes it is hard to debug the reason of the failure and our logging doesn't provide enough information at times to narrow down the issue. To improve the debugging experience, I've enhanced our logging to add a lot more information about what the ProcessGroup is doing under the hood. #Closes: https://github.com/pytorch/pytorch/issues/45310 Sample output: ``` > I1002 15:18:48.539551 1822062 ProcessGroupNCCL.cpp:528] [Rank 2] NCCL watchdog thread started! > I1002 15:18:48.539533 1821946 ProcessGroupNCCL.cpp:492] [Rank 2] ProcessGroupNCCL initialized with following options: > NCCL_ASYNC_ERROR_HANDLING: 0 > NCCL_BLOCKING_WAIT: 1 > TIMEOUT(ms): 1000 > USE_HIGH_PRIORITY_STREAM: 0 > I1002 15:18:51.080338 1822035 ProcessGroupNCCL.cpp:530] [Rank 1] NCCL watchdog thread terminated normally > I1002 15:18:52.161218 1821930 ProcessGroupNCCL.cpp:385] [Rank 0] Wrote aborted communicator id to store: NCCLABORTEDCOMM:a0e17500002836080c8384c50000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 > I1002 15:18:52.161238 1821930 ProcessGroupNCCL.cpp:388] [Rank 0] Caught collective operation timeout for work: WorkNCCL(OpType=ALLREDUCE, TensorShape=[10], Timeout(ms)=1000) > I1002 15:18:52.162120 1821957 ProcessGroupNCCL.cpp:530] [Rank 0] NCCL watchdog thread terminated normally > I1002 15:18:58.539937 1822062 ProcessGroupNCCL.cpp:649] [Rank 2] Found key in store: NCCLABORTEDCOMM:a0e17500002836080c8384c50000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, from rank: 0, aborting appropriate communicators > I1002 15:19:34.740937 1822062 ProcessGroupNCCL.cpp:662] [Rank 2] Aborted communicators for key in store: NCCLABORTEDCOMM:a0e17500002836080c8384c50000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 > I1002 15:19:34.741678 1822062 ProcessGroupNCCL.cpp:530] [Rank 2] NCCL watchdog thread terminated normally ``` ghstack-source-id: 113731163 Test Plan: waitforbuildbot Reviewed By: osalpekar Differential Revision: D24093032 fbshipit-source-id: 240b03562f8ccccc3d872538f5e331df598ceca7 --- test/distributed/test_c10d.py | 5 +- torch/lib/c10d/ProcessGroup.cpp | 56 ++++++ torch/lib/c10d/ProcessGroup.hpp | 38 ++++ torch/lib/c10d/ProcessGroupNCCL.cpp | 170 ++++++++++++------ torch/lib/c10d/ProcessGroupNCCL.hpp | 35 ++-- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 26 ++- 6 files changed, 255 insertions(+), 75 deletions(-) diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 9d0c19bef7b3..2e8fc8854804 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -1619,6 +1619,10 @@ def test_init_no_gpus(self): c10d.ProcessGroupNCCL(store, self.rank, self.world_size) +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) 
class ProcessGroupNCCLTest(TestCase): MAIN_PROCESS_RANK = 0 @@ -3828,7 +3832,6 @@ def test_multi_limit_multi_dtype(self): self.assertEqual([[0], [1], [2, 4], [3, 5]], result) -@skip_if_rocm @unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") class NcclErrorHandlingTest(MultiProcessTestCase): def setUp(self): diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 5c362a42fcf5..83035666d7e9 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -4,6 +4,62 @@ namespace c10d { +std::string opTypeToString(OpType opType) { + switch (opType) { + case OpType::BROADCAST: + return "BROADCAST"; + case OpType::ALLREDUCE: + return "ALLREDUCE"; + case OpType::ALLREDUCE_COALESCED: + return "ALLREDUCE_COALESCED"; + case OpType::REDUCE: + return "REDUCE"; + case OpType::ALLGATHER: + return "ALLGATHER"; + case OpType::ALLGATHER_BASE: + return "ALLGATHER_BASE"; + case OpType::ALLGATHER_COALESCED: + return "ALLGATHER_COALESCED"; + case OpType::GATHER: + return "GATHER"; + case OpType::SCATTER: + return "SCATTER"; + case OpType::REDUCE_SCATTER: + return "REDUCE_SCATTER"; + case OpType::ALLTOALL_BASE: + return "ALLTOALL_BASE"; + case OpType::ALLTOALL: + return "ALLTOALL"; + case OpType::SEND: + return "SEND"; + case OpType::RECV: + return "RECV"; + case OpType::RECVANYSOURCE: + return "RECVANYSOURCE"; + case OpType::BARRIER: + return "BARRIER"; + case OpType::UNKNOWN: + return "UNKNOWN"; + default: + TORCH_INTERNAL_ASSERT("Unknown op type!"); + } + return "UNKNOWN"; +} + +bool isP2POp(OpType opType) { + return opType == OpType::SEND || opType == OpType::RECV || + opType == OpType::RECVANYSOURCE; +} + +ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {} + +ProcessGroup::Work::Work(int rank, OpType opType) + : rank_(rank), opType_(opType) {} + +OpType ProcessGroup::Work::retrieveOpType() { + return opType_; +} + ProcessGroup::Work::~Work() {} bool ProcessGroup::Work::isCompleted() { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 59d40d2427a8..01d835d913cd 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -15,6 +15,32 @@ constexpr auto kNoTimeout = std::chrono::milliseconds(0); namespace c10d { +enum class OpType : std::uint8_t { + BROADCAST = 0, + ALLREDUCE = 1, + ALLREDUCE_COALESCED = 2, + REDUCE = 3, + ALLGATHER = 4, + ALLGATHER_BASE = 5, + ALLGATHER_COALESCED = 6, + GATHER = 7, + SCATTER = 8, + REDUCE_SCATTER = 9, + ALLTOALL_BASE = 10, + ALLTOALL = 11, + SEND = 12, + RECV = 13, + RECVANYSOURCE = 14, + BARRIER = 15, + UNKNOWN = 100, +}; + +// Converts OpType to human readable string. +std::string opTypeToString(OpType opType); + +// Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) +bool isP2POp(OpType opType); + // ProcessGroup is a base class that captures collective and point to // point communication in a fixed set of processes. // @@ -39,6 +65,10 @@ class ProcessGroup { public: class Work { public: + Work(); + + Work(int rank, OpType opType); + virtual ~Work(); // Checks if request has completed. Non-blocking operation. @@ -93,6 +123,8 @@ class ProcessGroup { // work. Only NCCL backend is currently supported. virtual c10::intrusive_ptr getFuture(); + OpType retrieveOpType(); + protected: // Completes the work object and optionally sets the exception in a // thread-safe manner. Notifies all waiting condition variables as well. 
@@ -106,6 +138,12 @@ class ProcessGroup { std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; + + // Current rank of the node. + const int rank_; + + // Operation type that this work object refers to. + OpType opType_; }; explicit ProcessGroup(int rank, int size); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 6e45b8594f9b..809dd8e07172 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -241,8 +241,22 @@ constexpr int64_t kSynchronizeBusyWaitMillis = 10; const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; -ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) - : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + return output << "WorkNCCL(" + << "OpType=" << opTypeToString(workNCCL.opType_) + << ", TensorShape=" << (*workNCCL.outputs_)[0].sizes() + << ", Timeout(ms)=" << workNCCL.opTimeout_.count() << ")"; +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + const std::vector& devices, + int rank, + OpType opType) + : Work(rank, opType), + devices_(devices), + workStartTime_(std::chrono::steady_clock::now()) { // Creates the CUDA event wrappers // Note: The actual events are lazily created when first recorded to with // DEFAULT_FLAGS = cudaEventDisableTiming. @@ -252,7 +266,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) - : std::enable_shared_from_this(w), + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), devices_(w.devices_), cudaEvents_(w.cudaEvents_), ncclComms_(w.ncclComms_), @@ -375,9 +390,19 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( ncclComm->ncclCommAbort(); const auto& storeKey = getNcclAbortedCommStoreKey( buildNcclUniqueIdStr(ncclComm->getNcclId())); - store_->set(storeKey, {}); - LOG(INFO) << "Wrote aborted communicator id to store: " << storeKey; + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Wrote aborted communicator id to store: " << storeKey; } + LOG(INFO) << "[Rank " << rank_ + << "] Caught collective operation timeout for work: " + << (*this); throw std::runtime_error("Operation timed out!"); } // Check for errors and throw appropriate exception. 
@@ -430,7 +455,6 @@ void ProcessGroupNCCL::parseNcclAsyncErrorHandling() { auto val = std::stoi(errorHandle); if (val == 1) { asyncErrorHandling_ = true; - LOG(INFO) << "[Rank " << rank_ << "] NCCL Async Error Handling enabled."; } else if (val != 0) { throw std::runtime_error( "Invalid value for environment variable: " + @@ -483,6 +507,12 @@ ProcessGroupNCCL::ProcessGroupNCCL( if (asyncErrorHandling_) { workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this); } + LOG(INFO) << "[Rank " << rank_ + << "] ProcessGroupNCCL initialized with following options:" + << "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << "\nNCCL_BLOCKING_WAIT: " << blockingWait_ + << "\nTIMEOUT(ms): " << opTimeout_.count() + << "\nUSE_HIGH_PRIORITY_STREAM: " << isHighPriorityStream_; } ProcessGroupNCCL::~ProcessGroupNCCL() { @@ -513,12 +543,17 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { void ProcessGroupNCCL::ncclCommWatchdog() { try { + LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!"; ncclCommWatchdogInternal(); - LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread terminated normally"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated normally"; } catch (std::exception& e) { - LOG(INFO) << "NCCL watchdog thread terminated with exception: " << e.what(); + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with exception: " + << e.what(); } catch (...) { - LOG(INFO) << "NCCL watchdog thread terminated with unknown exception"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with unknown exception"; } } @@ -539,10 +574,12 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } if (checkForNCCLErrors(ncclComms)) { - LOG(INFO) << "Received NCCL errors for communicators in the cache"; + LOG(INFO) << "[Rank " << rank_ + << "] Received NCCL errors for communicators in the cache"; if (blockingWait_ || asyncErrorHandling_) { - LOG(INFO) << "Aborting communicators that received errors"; + LOG(INFO) << "[Rank " << rank_ + << "] Aborting communicators that received errors"; // We abort NCCL communicators that have received errors from this // thread, and exceptions are set on the corresponding work objects. // The workCleanupThread will then loop through the unfinished @@ -559,7 +596,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // a communicator the application receives an exception and its // their responsibility to destroy the process group and recreate // it to recover from errors. - abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); + abortedCommIds.emplace( + buildNcclUniqueIdStr(ncclComm->getNcclId())); } } } @@ -578,7 +616,10 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // Check for Timeouts in the WorkNCCL Operations, and abort all // communicators accordingly. 
if (work.timedOut()) { - LOG(INFO) << "[" << rank_ << "] caught collective operation timeout"; + LOG(INFO) + << "[Rank " << rank_ + << "] Watchdog caught collective operation timeout for work: " + << work; std::exception_ptr exception_ptr = std::make_exception_ptr( std::runtime_error("NCCL Operation Timed Out")); work.setException(exception_ptr); @@ -601,8 +642,15 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { for (const auto& abortedCommId : abortedCommIds) { abortedComms_.emplace(abortedCommId); const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId); - store_->set(storeKey, {}); - LOG(INFO) << "Watchdog wrote aborted communicator id to store: " + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Watchdog wrote aborted communicator id to store: " << storeKey; } @@ -616,7 +664,11 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { store_->wait( {storeKey}, std::chrono::milliseconds(kWaitForAbortCommStoreKey)); - LOG(INFO) << "Found key in store: " << storeKey + auto val = store_->get(storeKey); + std::string rank(reinterpret_cast(val.data()), val.size()); + LOG(INFO) << "[Rank " << rank_ + << "] Found key in store: " << storeKey + << ", from rank: " << rank << ", aborting appropriate communicators"; // Now abort the appropriate communicators. @@ -627,7 +679,9 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { ncclComm->ncclCommAbort(); } abortedComms_.emplace(commId); - LOG(INFO) << "Aborted communicators for key in store: " << storeKey; + LOG(INFO) << "[Rank " << rank_ + << "] Aborted communicators for key in store: " + << storeKey; } catch (std::exception& e) { VLOG(1) << "Did not find key in store: " << storeKey << ", error: " << e.what(); @@ -726,7 +780,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { std::vector>& ProcessGroupNCCL::getNCCLComm( const std::string& devicesKey, const std::vector& devices, - NCCLCommType commType, + OpType opType, int p2pRank) { // Sanity check if (devicesKey.empty()) { @@ -755,7 +809,7 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( ncclUniqueId ncclID; // For point-to-point communication, lower rank of the two will get unique id. - if (rank_ == 0 || (commType != NCCLCommType::COLL && p2pRank == 0)) { + if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID)); } @@ -793,12 +847,12 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // GPU world size and GPU rank int numRanks, rank; - if (commType == NCCLCommType::COLL) { + if (!isP2POp(opType)) { numRanks = getSize() * devices.size(); rank = getRank() * devices.size() + i; } else { - // For point-to-point operation, there are only 2 processes involved so - // the GPU rank is either 0 or 1. + // For point-to-point operation, there are only 2 processes involved so + // the GPU rank is either 0 or 1. 
numRanks = 2; rank = p2pRank; } @@ -816,7 +870,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( std::lock_guard lock(mutex_); if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) { futureNCCLCallbackStreams_[deviceIndex] = - std::make_shared(at::cuda::getStreamFromPool(isHighPriorityStream_)); + std::make_shared( + at::cuda::getStreamFromPool(isHighPriorityStream_)); } } @@ -947,8 +1002,10 @@ std::vector flatten_for_scatter_gather( } // namespace std::shared_ptr ProcessGroupNCCL::initWork( - std::vector devices) { - return std::make_shared(devices); + std::vector devices, + int rank, + OpType opType) { + return std::make_shared(devices, rank, opType); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -982,7 +1039,8 @@ void ProcessGroupNCCL::workEnqueue( } } ProcessGroupNCCL::Options::Options() - : opTimeout(kProcessGroupNCCLOpTimeoutMillis), isHighPriorityStream(false) {} + : opTimeout(kProcessGroupNCCLOpTimeoutMillis), + isHighPriorityStream(false) {} template std::shared_ptr ProcessGroupNCCL::collective( @@ -990,16 +1048,17 @@ std::shared_ptr ProcessGroupNCCL::collective( std::vector& outputs, Fn fn, PreProcess pre, - PostProcess post) { + PostProcess post, + OpType opType) { const auto devices = getDeviceList(inputs); const auto key = getKeyFromDevices(devices); - auto& ncclComms = getNCCLComm(key, devices); + auto& ncclComms = getNCCLComm(key, devices, opType); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices); + auto work = initWork(devices, rank_, opType); // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. @@ -1043,11 +1102,13 @@ std::shared_ptr ProcessGroupNCCL::collective( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; (*work->cudaEvents_)[i].record(ncclStream); work->ncclComms_[i] = ncclComms[i]; - work->blockingWait_ = blockingWait_; - work->opTimeout_ = opTimeout_; - work->store_ = store_; } + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->opTimeout_ = opTimeout_; + work->store_ = store_; + if (asyncErrorHandling_) { workEnqueue(work); } @@ -1060,21 +1121,21 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, - NCCLCommType commType, + OpType opType, PreProcess pre, PostProcess post) { const auto devices = getDeviceList(tensors); const auto key = getKeySendRecv(rank_, peer); int p2pRank = rank_ < peer ? 0 : 1; - auto& ncclComms = getNCCLComm(key, devices, commType, p2pRank); + auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices); + auto work = initWork(devices, rank_, opType); - if (commType == NCCLCommType::RECV) { + if (opType == OpType::RECV) { // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. 
work->outputs_ = std::make_shared>(tensors); @@ -1130,13 +1191,15 @@ template std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, - Fn fn) { + Fn fn, + OpType opType) { return collective( inputs, outputs, fn, [](std::vector&) {}, - [](std::vector&) {}); + [](std::vector&) {}, + opType); } template @@ -1144,12 +1207,12 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, - NCCLCommType type) { + OpType opType) { return pointToPoint( tensor, fn, peer, - type, + opType, [](std::vector&) {}, [](std::vector&) {}); } @@ -1174,7 +1237,8 @@ std::shared_ptr ProcessGroupNCCL::allreduce( getNcclReduceOp(opts.reduceOp, input), comm, stream.stream()); - }); + }, + OpType::ALLREDUCE); } std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( @@ -1204,7 +1268,8 @@ std::shared_ptr ProcessGroupNCCL::broadcast( root, comm, stream.stream()); - }); + }, + OpType::BROADCAST); } std::shared_ptr ProcessGroupNCCL::reduce( @@ -1229,7 +1294,8 @@ std::shared_ptr ProcessGroupNCCL::reduce( root, comm, stream.stream()); - }); + }, + OpType::REDUCE); } std::shared_ptr ProcessGroupNCCL::allgather( @@ -1272,7 +1338,8 @@ std::shared_ptr ProcessGroupNCCL::allgather( outputTensors[i][j].copy_(outputFlattened[i][j], true); } } - }); + }, + OpType::ALLGATHER); } std::shared_ptr ProcessGroupNCCL::allgather_coalesced( @@ -1324,7 +1391,8 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( } } }, - [&](std::vector& ncclStreams) {}); + [&](std::vector& ncclStreams) {}, + OpType::REDUCE_SCATTER); } std::shared_ptr ProcessGroupNCCL::barrier( @@ -1394,7 +1462,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }); + }, + OpType::ALLTOALL_BASE); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1426,7 +1495,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }); + }, + OpType::ALLTOALL_BASE); } } @@ -1450,7 +1520,7 @@ std::shared_ptr ProcessGroupNCCL::send( stream.stream()); }, dstRank, - NCCLCommType::SEND); + OpType::SEND); return ret; } @@ -1459,7 +1529,7 @@ std::shared_ptr ProcessGroupNCCL::recv( int srcRank, int /* unused */) { check_gpu_tensors(tensors); - auto ret= pointToPoint( + auto ret = pointToPoint( tensors, [&](at::Tensor& output, ncclComm_t comm, @@ -1474,7 +1544,7 @@ std::shared_ptr ProcessGroupNCCL::recv( stream.stream()); }, srcRank, - NCCLCommType::RECV); + OpType::RECV); return ret; } #else diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b8b3d5aabd35..13b8c72b318a 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -23,13 +24,6 @@ constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT"; // Handling with NCCL. constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; -// NCCL Commmunication type -enum class NCCLCommType : std::uint8_t { - SEND = 0, - RECV, - COLL, -}; - // ProcessGroupNCCL implements NCCL bindings for c10d. 
// // All functions of the class are expected to be called in the same order @@ -71,7 +65,7 @@ class ProcessGroupNCCL : public ProcessGroup { public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices - WorkNCCL(const std::vector& devices); + WorkNCCL(const std::vector& devices, int rank, OpType opType); // Copy constructor doing partial copy without outputs_. Cleanup thread // monitors and removes finished works. However it will deadlock when // destructs outputs_ tensors who are view tensors in autograd graph. @@ -147,6 +141,10 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) const; + friend std::ostream& operator<<( + std::ostream& output, + const WorkNCCL& workNCCL); + private: // Helper function for synchronize void synchronizeInternal(std::chrono::milliseconds timeout); @@ -166,6 +164,7 @@ class ProcessGroupNCCL : public ProcessGroup { // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; + // Store streams that run FutureNCCL then callbacks. std::vector> futureNCCLCallbackStreams_; @@ -467,7 +466,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector>& getNCCLComm( const std::string& devicesKey, const std::vector& devices, - NCCLCommType commType = NCCLCommType::COLL, + OpType opType, int p2pRank = 0); // Wrapper method which can be overridden for tests. @@ -475,7 +474,9 @@ class ProcessGroupNCCL : public ProcessGroup { const std::vector>& ncclComms); virtual std::shared_ptr initWork( - std::vector devices); + std::vector devices, + int rank, + OpType opType); private: // Helper that encapsulates work shared across all collective communication @@ -488,14 +489,16 @@ class ProcessGroupNCCL : public ProcessGroup { std::shared_ptr collective( std::vector& input, std::vector& output, - Fn fn); + Fn fn, + OpType opType); template std::shared_ptr collective( std::vector& input, std::vector& output, Fn fn, PreProcess pre, - PostProcess post); + PostProcess post, + OpType opType); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -505,13 +508,13 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& tensor, Fn fn, int peer, - NCCLCommType commType); + OpType opType); template std::shared_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, - NCCLCommType commType, + OpType opType, PreProcess pre, PostProcess post); @@ -537,8 +540,8 @@ class ProcessGroupNCCL : public ProcessGroup { // accordingly. void parseNcclBlockingWait(); - // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_ - // accordingly. + // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets + // asyncErrorHandling_ accordingly. 
void parseNcclAsyncErrorHandling(); void workCleanupLoop(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 93f633938e18..0df197d17cbb 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -16,8 +16,10 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLSimulateErrors( const std::vector& devices, - bool simulate_error) - : WorkNCCL(devices), simulate_error_(simulate_error) {} + bool simulate_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), simulate_error_(simulate_error) {} std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) @@ -55,8 +57,11 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { } std::shared_ptr initWork( - std::vector devices) override { - return std::make_shared(devices, simulate_error_); + std::vector devices, + int rank, + c10d::OpType opType) override { + return std::make_shared( + devices, simulate_error_, rank, opType); } size_t getNCCLCommCacheSize() { @@ -79,8 +84,11 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLTimedoutErrors( const std::vector& devices, - bool set_timedout_error) - : WorkNCCL(devices), set_timedout_error_(set_timedout_error) {} + bool set_timedout_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), + set_timedout_error_(set_timedout_error) {} private: bool isCompleted() override { @@ -105,9 +113,11 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { set_timedout_error_(false) {} std::shared_ptr initWork( - std::vector devices) override { + std::vector devices, + int rank, + c10d::OpType opType) override { return std::make_shared( - devices, set_timedout_error_); + devices, set_timedout_error_, rank, opType); } void set_timedout_error() { From 0927e02a6aea1cea8b7b324d2b9fca5dcda494d3 Mon Sep 17 00:00:00 2001 From: Hao Lu Date: Wed, 7 Oct 2020 14:04:57 -0700 Subject: [PATCH 40/69] [caffe2] Do not run RemoveOpsByType on recurrent networks (#45986) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45986 Recurrent networks have subnets that are not well supported by `RemoveOpsByType`. Here we exclude recurrent networks by adding the same check as in memonger. 
Test Plan: ``` buck test //caffe2/caffe2/fb/predictor:black_box_predictor_test ``` AdIndexer canary for sanity check: https://www.internalfb.com/intern/ads/canary/430059485214766620 Differential Revision: D24167284 fbshipit-source-id: fa90d1c1f34af334a599d879af09d4c0bf7c27bd --- caffe2/predictor/transforms.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/caffe2/predictor/transforms.cc b/caffe2/predictor/transforms.cc index 03653d8ea2a4..72a6098b7e95 100644 --- a/caffe2/predictor/transforms.cc +++ b/caffe2/predictor/transforms.cc @@ -90,7 +90,7 @@ void RenameOutputs( void RenameInputsInChildren( const string& from, const string& to, - std::shared_ptr net, + caffe2::NetDef* net, std::unordered_map>& children) { VLOG(2) << "RenameInputsInChildren (from=" << from << ", to=" << to << ")"; if (children.count(from) == 0) { @@ -106,7 +106,7 @@ void RenameInputsInChildren( void RenameOutputInParents( const std::string& from, const std::string& to, - std::shared_ptr net, + caffe2::NetDef* net, std::unordered_map>& parents) { VLOG(2) << "RenameOutputInParents (from=" << from << ", to=" << to << ")"; if (parents.count(from) == 0) { @@ -225,7 +225,13 @@ bool FoundOpCandidate( // extra complexity is handled in FoundOpCandidate. void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) { int num_removed = 0; - std::shared_ptr net = graph.predict_net_def; + NetDef* net = graph.predict_net_def.get(); + for (auto& op : net->op()) { + if (op.type() == "RecurrentNetwork") { + LOG(INFO) << "RemoveOpsByType does not support RecurrentNetwork yet"; + return; + } + } std::unordered_set inputs( graph.input_names.begin(), graph.input_names.end()); @@ -239,7 +245,7 @@ void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) { for (const auto& o : graph.output_names) { net->add_external_output(o); } - onnx::SsaRewrite(nullptr, net.get()); + onnx::SsaRewrite(nullptr, net); // clear external_outputs net->mutable_external_output()->Clear(); graph.predictor_net_ssa_rewritten = true; From ce82b522c8f343419a1fbb1bf2405b89121b0335 Mon Sep 17 00:00:00 2001 From: Andy Zhang Date: Wed, 7 Oct 2020 15:01:06 -0700 Subject: [PATCH 41/69] Define objects using classes instead of namedtuples in torch.utils.data._utils.worker (#45870) Summary: This PR fixes a bug when torch is used with pyspark, by converting namedtuples in `torch.utils.data._utils.worker` into classes. Before this PR, creating an IterableDataset and then running `list(torch.utils.data.DataLoader(MyIterableDataset(...), num_workers=2)))` will not terminate, if pyspark is also being used. This is because pyspark hijacks namedtuples to make them pickleable ([see here](https://github.com/apache/spark/blob/master/python/pyspark/serializers.py#L370)). So `_IterableDatasetStopIteration` would be modified, and then the check at [this line in dataloader.py](https://github.com/pytorch/pytorch/blob/5472426b9f85c8107aade5d256cf4cde572eab5c/torch/utils/data/dataloader.py#L1072) is never true. Converting the namedtuples to classes avoids this hijack and allows the iteration to correctly stop when signaled. 
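The fix boils down to using a frozen dataclass as the sentinel and keeping the `isinstance` check on the consuming side. A toy sketch of the pattern (`_drain` below is a made-up stand-in for the DataLoader's main-process loop, not the real implementation):

```
from dataclasses import dataclass

# An ordinary class is not touched by pyspark's namedtuple patching, so its
# identity survives the round trip through the worker result queue.
@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    worker_id: int

def _drain(queue_items):
    # Stand-in for the main-process loop: consume batches until a worker
    # signals that its iterator is exhausted.
    for item in queue_items:
        if isinstance(item, _IterableDatasetStopIteration):
            print(f"worker {item.worker_id} finished")
            break
        print("got batch:", item)

_drain(["batch0", "batch1", _IterableDatasetStopIteration(worker_id=1), "never reached"])
```

The frozen dataclass keeps the immutability and field-wise equality the namedtuple provided, without going through `collections.namedtuple`, the hook that pyspark patches per the link above.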
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45870 Reviewed By: ngimel Differential Revision: D24162748 Pulled By: albanD fbshipit-source-id: 52f009784500fa594b2bbd15a8b2e486e00c37fb --- torch/utils/data/_utils/worker.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 8802f40ecdb9..7a53d61feae5 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -7,7 +7,7 @@ import torch import random import os -from collections import namedtuple +from dataclasses import dataclass from torch._six import queue from torch._utils import ExceptionWrapper from typing import Union @@ -110,10 +110,14 @@ def get_worker_info(): r"""Dummy class used to signal the end of an IterableDataset""" -_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration', ['worker_id']) +@dataclass(frozen=True) +class _IterableDatasetStopIteration(object): + worker_id: int r"""Dummy class used to resume the fetching when worker reuse is enabled""" -_ResumeIteration = namedtuple('_ResumeIteration', []) +@dataclass(frozen=True) +class _ResumeIteration(object): + pass def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, auto_collation, collate_fn, drop_last, seed, init_fn, worker_id, From 505be08c75d5c5451586a3069124d324a483d6ef Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 7 Oct 2020 15:08:05 -0700 Subject: [PATCH 42/69] [dist_optim] serialize compilation when creating dist_optim (#45871) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45871 Attempt to fix https://github.com/pytorch/pytorch/issues/45845 Test Plan: Imported from OSS Reviewed By: pritamdamania87 Differential Revision: D24125209 Pulled By: wanchaol fbshipit-source-id: e3697dd6ef107d8153d2a82d78a17c66d109b4fa --- torch/distributed/optim/optimizer.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index bb04e2dde3aa..c7f8e3236776 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -3,6 +3,7 @@ import torch.distributed.rpc as rpc import torch.optim as optim import torch.jit as jit +import torch.nn as nn from torch import Tensor from torch.distributed.rpc import RRef from .functional_adagrad import _FunctionalAdagrad @@ -28,7 +29,12 @@ class _ScriptLocalOptimizerInterface(object): def step(self, autograd_ctx_id: int) -> None: pass -class _ScriptLocalOptimizer(jit.ScriptModule): +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. 
+ # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): super().__init__() self._local_params = [rref.local_value() for rref in local_params_rref] @@ -37,7 +43,7 @@ def __init__(self, optim_cls, local_params_rref, *args, **kwargs): *args, **kwargs) - @jit.script_method + @jit.export def step(self, autograd_ctx_id: int): all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) # apply functional optimizer step with a list of gradients @@ -49,6 +55,8 @@ def step(self, autograd_ctx_id: int): self.optim.step(grads) +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim class _LocalOptimizer(object): # Ideally we would only need to share a lock for instances of # _LocalOptimizer that deal with the same parameters. We are @@ -87,8 +95,12 @@ def _local_optimizer_step(local_optim_rref, autograd_ctx_id): # new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): - return rpc.RRef( - _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs), _ScriptLocalOptimizerInterface) + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef( + script_optim, _ScriptLocalOptimizerInterface) @jit.script def _script_local_optimizer_step( From de0d0bd5ee50e230e261d8a241c84185436fdc1c Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 7 Oct 2020 16:39:33 -0700 Subject: [PATCH 43/69] Revert D24093032: Improve logging in ProcessGroupNCCL for debugging purposes. 
Test Plan: revert-hammer Differential Revision: D24093032 (https://github.com/pytorch/pytorch/commit/c8d76ff7dc21d426c7f2851803071820865a828f) Original commit changeset: 240b03562f8c fbshipit-source-id: dab7d54a5ba517bb308a1825b0d63ed146e5269d --- test/distributed/test_c10d.py | 5 +- torch/lib/c10d/ProcessGroup.cpp | 56 ------ torch/lib/c10d/ProcessGroup.hpp | 38 ---- torch/lib/c10d/ProcessGroupNCCL.cpp | 170 ++++++------------ torch/lib/c10d/ProcessGroupNCCL.hpp | 35 ++-- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 26 +-- 6 files changed, 75 insertions(+), 255 deletions(-) diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 2e8fc8854804..9d0c19bef7b3 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -1619,10 +1619,6 @@ def test_init_no_gpus(self): c10d.ProcessGroupNCCL(store, self.rank, self.world_size) -@unittest.skipIf( - TEST_WITH_TSAN, - "TSAN is not fork-safe since we're forking in a multi-threaded environment", -) class ProcessGroupNCCLTest(TestCase): MAIN_PROCESS_RANK = 0 @@ -3832,6 +3828,7 @@ def test_multi_limit_multi_dtype(self): self.assertEqual([[0], [1], [2, 4], [3, 5]], result) +@skip_if_rocm @unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") class NcclErrorHandlingTest(MultiProcessTestCase): def setUp(self): diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 83035666d7e9..5c362a42fcf5 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -4,62 +4,6 @@ namespace c10d { -std::string opTypeToString(OpType opType) { - switch (opType) { - case OpType::BROADCAST: - return "BROADCAST"; - case OpType::ALLREDUCE: - return "ALLREDUCE"; - case OpType::ALLREDUCE_COALESCED: - return "ALLREDUCE_COALESCED"; - case OpType::REDUCE: - return "REDUCE"; - case OpType::ALLGATHER: - return "ALLGATHER"; - case OpType::ALLGATHER_BASE: - return "ALLGATHER_BASE"; - case OpType::ALLGATHER_COALESCED: - return "ALLGATHER_COALESCED"; - case OpType::GATHER: - return "GATHER"; - case OpType::SCATTER: - return "SCATTER"; - case OpType::REDUCE_SCATTER: - return "REDUCE_SCATTER"; - case OpType::ALLTOALL_BASE: - return "ALLTOALL_BASE"; - case OpType::ALLTOALL: - return "ALLTOALL"; - case OpType::SEND: - return "SEND"; - case OpType::RECV: - return "RECV"; - case OpType::RECVANYSOURCE: - return "RECVANYSOURCE"; - case OpType::BARRIER: - return "BARRIER"; - case OpType::UNKNOWN: - return "UNKNOWN"; - default: - TORCH_INTERNAL_ASSERT("Unknown op type!"); - } - return "UNKNOWN"; -} - -bool isP2POp(OpType opType) { - return opType == OpType::SEND || opType == OpType::RECV || - opType == OpType::RECVANYSOURCE; -} - -ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {} - -ProcessGroup::Work::Work(int rank, OpType opType) - : rank_(rank), opType_(opType) {} - -OpType ProcessGroup::Work::retrieveOpType() { - return opType_; -} - ProcessGroup::Work::~Work() {} bool ProcessGroup::Work::isCompleted() { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 01d835d913cd..59d40d2427a8 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -15,32 +15,6 @@ constexpr auto kNoTimeout = std::chrono::milliseconds(0); namespace c10d { -enum class OpType : std::uint8_t { - BROADCAST = 0, - ALLREDUCE = 1, - ALLREDUCE_COALESCED = 2, - REDUCE = 3, - ALLGATHER = 4, - ALLGATHER_BASE = 5, - ALLGATHER_COALESCED = 6, - GATHER = 7, - SCATTER = 8, - REDUCE_SCATTER = 9, - 
ALLTOALL_BASE = 10, - ALLTOALL = 11, - SEND = 12, - RECV = 13, - RECVANYSOURCE = 14, - BARRIER = 15, - UNKNOWN = 100, -}; - -// Converts OpType to human readable string. -std::string opTypeToString(OpType opType); - -// Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) -bool isP2POp(OpType opType); - // ProcessGroup is a base class that captures collective and point to // point communication in a fixed set of processes. // @@ -65,10 +39,6 @@ class ProcessGroup { public: class Work { public: - Work(); - - Work(int rank, OpType opType); - virtual ~Work(); // Checks if request has completed. Non-blocking operation. @@ -123,8 +93,6 @@ class ProcessGroup { // work. Only NCCL backend is currently supported. virtual c10::intrusive_ptr getFuture(); - OpType retrieveOpType(); - protected: // Completes the work object and optionally sets the exception in a // thread-safe manner. Notifies all waiting condition variables as well. @@ -138,12 +106,6 @@ class ProcessGroup { std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; - - // Current rank of the node. - const int rank_; - - // Operation type that this work object refers to. - OpType opType_; }; explicit ProcessGroup(int rank, int size); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 809dd8e07172..6e45b8594f9b 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -241,22 +241,8 @@ constexpr int64_t kSynchronizeBusyWaitMillis = 10; const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; -std::ostream& operator<<( - std::ostream& output, - const ProcessGroupNCCL::WorkNCCL& workNCCL) { - return output << "WorkNCCL(" - << "OpType=" << opTypeToString(workNCCL.opType_) - << ", TensorShape=" << (*workNCCL.outputs_)[0].sizes() - << ", Timeout(ms)=" << workNCCL.opTimeout_.count() << ")"; -} - -ProcessGroupNCCL::WorkNCCL::WorkNCCL( - const std::vector& devices, - int rank, - OpType opType) - : Work(rank, opType), - devices_(devices), - workStartTime_(std::chrono::steady_clock::now()) { +ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) + : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { // Creates the CUDA event wrappers // Note: The actual events are lazily created when first recorded to with // DEFAULT_FLAGS = cudaEventDisableTiming. @@ -266,8 +252,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) - : Work(w.rank_, w.opType_), - std::enable_shared_from_this(w), + : std::enable_shared_from_this(w), devices_(w.devices_), cudaEvents_(w.cudaEvents_), ncclComms_(w.ncclComms_), @@ -390,19 +375,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( ncclComm->ncclCommAbort(); const auto& storeKey = getNcclAbortedCommStoreKey( buildNcclUniqueIdStr(ncclComm->getNcclId())); - auto rankStr = std::to_string(rank_); - store_->set( - storeKey, - std::vector( - reinterpret_cast(rankStr.data()), - reinterpret_cast(rankStr.data()) + - rankStr.size())); - LOG(INFO) << "[Rank " << rank_ - << "] Wrote aborted communicator id to store: " << storeKey; + store_->set(storeKey, {}); + LOG(INFO) << "Wrote aborted communicator id to store: " << storeKey; } - LOG(INFO) << "[Rank " << rank_ - << "] Caught collective operation timeout for work: " - << (*this); throw std::runtime_error("Operation timed out!"); } // Check for errors and throw appropriate exception. 
@@ -455,6 +430,7 @@ void ProcessGroupNCCL::parseNcclAsyncErrorHandling() { auto val = std::stoi(errorHandle); if (val == 1) { asyncErrorHandling_ = true; + LOG(INFO) << "[Rank " << rank_ << "] NCCL Async Error Handling enabled."; } else if (val != 0) { throw std::runtime_error( "Invalid value for environment variable: " + @@ -507,12 +483,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( if (asyncErrorHandling_) { workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this); } - LOG(INFO) << "[Rank " << rank_ - << "] ProcessGroupNCCL initialized with following options:" - << "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ - << "\nNCCL_BLOCKING_WAIT: " << blockingWait_ - << "\nTIMEOUT(ms): " << opTimeout_.count() - << "\nUSE_HIGH_PRIORITY_STREAM: " << isHighPriorityStream_; } ProcessGroupNCCL::~ProcessGroupNCCL() { @@ -543,17 +513,12 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { void ProcessGroupNCCL::ncclCommWatchdog() { try { - LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!"; ncclCommWatchdogInternal(); - LOG(INFO) << "[Rank " << rank_ - << "] NCCL watchdog thread terminated normally"; + LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread terminated normally"; } catch (std::exception& e) { - LOG(INFO) << "[Rank " << rank_ - << "] NCCL watchdog thread terminated with exception: " - << e.what(); + LOG(INFO) << "NCCL watchdog thread terminated with exception: " << e.what(); } catch (...) { - LOG(INFO) << "[Rank " << rank_ - << "] NCCL watchdog thread terminated with unknown exception"; + LOG(INFO) << "NCCL watchdog thread terminated with unknown exception"; } } @@ -574,12 +539,10 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } if (checkForNCCLErrors(ncclComms)) { - LOG(INFO) << "[Rank " << rank_ - << "] Received NCCL errors for communicators in the cache"; + LOG(INFO) << "Received NCCL errors for communicators in the cache"; if (blockingWait_ || asyncErrorHandling_) { - LOG(INFO) << "[Rank " << rank_ - << "] Aborting communicators that received errors"; + LOG(INFO) << "Aborting communicators that received errors"; // We abort NCCL communicators that have received errors from this // thread, and exceptions are set on the corresponding work objects. // The workCleanupThread will then loop through the unfinished @@ -596,8 +559,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // a communicator the application receives an exception and its // their responsibility to destroy the process group and recreate // it to recover from errors. - abortedCommIds.emplace( - buildNcclUniqueIdStr(ncclComm->getNcclId())); + abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); } } } @@ -616,10 +578,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // Check for Timeouts in the WorkNCCL Operations, and abort all // communicators accordingly. 
if (work.timedOut()) { - LOG(INFO) - << "[Rank " << rank_ - << "] Watchdog caught collective operation timeout for work: " - << work; + LOG(INFO) << "[" << rank_ << "] caught collective operation timeout"; std::exception_ptr exception_ptr = std::make_exception_ptr( std::runtime_error("NCCL Operation Timed Out")); work.setException(exception_ptr); @@ -642,15 +601,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { for (const auto& abortedCommId : abortedCommIds) { abortedComms_.emplace(abortedCommId); const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId); - auto rankStr = std::to_string(rank_); - store_->set( - storeKey, - std::vector( - reinterpret_cast(rankStr.data()), - reinterpret_cast(rankStr.data()) + - rankStr.size())); - LOG(INFO) << "[Rank " << rank_ - << "] Watchdog wrote aborted communicator id to store: " + store_->set(storeKey, {}); + LOG(INFO) << "Watchdog wrote aborted communicator id to store: " << storeKey; } @@ -664,11 +616,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { store_->wait( {storeKey}, std::chrono::milliseconds(kWaitForAbortCommStoreKey)); - auto val = store_->get(storeKey); - std::string rank(reinterpret_cast(val.data()), val.size()); - LOG(INFO) << "[Rank " << rank_ - << "] Found key in store: " << storeKey - << ", from rank: " << rank + LOG(INFO) << "Found key in store: " << storeKey << ", aborting appropriate communicators"; // Now abort the appropriate communicators. @@ -679,9 +627,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { ncclComm->ncclCommAbort(); } abortedComms_.emplace(commId); - LOG(INFO) << "[Rank " << rank_ - << "] Aborted communicators for key in store: " - << storeKey; + LOG(INFO) << "Aborted communicators for key in store: " << storeKey; } catch (std::exception& e) { VLOG(1) << "Did not find key in store: " << storeKey << ", error: " << e.what(); @@ -780,7 +726,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { std::vector>& ProcessGroupNCCL::getNCCLComm( const std::string& devicesKey, const std::vector& devices, - OpType opType, + NCCLCommType commType, int p2pRank) { // Sanity check if (devicesKey.empty()) { @@ -809,7 +755,7 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( ncclUniqueId ncclID; // For point-to-point communication, lower rank of the two will get unique id. - if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { + if (rank_ == 0 || (commType != NCCLCommType::COLL && p2pRank == 0)) { C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID)); } @@ -847,12 +793,12 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // GPU world size and GPU rank int numRanks, rank; - if (!isP2POp(opType)) { + if (commType == NCCLCommType::COLL) { numRanks = getSize() * devices.size(); rank = getRank() * devices.size() + i; } else { - // For point-to-point operation, there are only 2 processes involved so - // the GPU rank is either 0 or 1. + // For point-to-point operation, there are only 2 processes involved so + // the GPU rank is either 0 or 1. 
numRanks = 2; rank = p2pRank; } @@ -870,8 +816,7 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( std::lock_guard lock(mutex_); if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) { futureNCCLCallbackStreams_[deviceIndex] = - std::make_shared( - at::cuda::getStreamFromPool(isHighPriorityStream_)); + std::make_shared(at::cuda::getStreamFromPool(isHighPriorityStream_)); } } @@ -1002,10 +947,8 @@ std::vector flatten_for_scatter_gather( } // namespace std::shared_ptr ProcessGroupNCCL::initWork( - std::vector devices, - int rank, - OpType opType) { - return std::make_shared(devices, rank, opType); + std::vector devices) { + return std::make_shared(devices); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1039,8 +982,7 @@ void ProcessGroupNCCL::workEnqueue( } } ProcessGroupNCCL::Options::Options() - : opTimeout(kProcessGroupNCCLOpTimeoutMillis), - isHighPriorityStream(false) {} + : opTimeout(kProcessGroupNCCLOpTimeoutMillis), isHighPriorityStream(false) {} template std::shared_ptr ProcessGroupNCCL::collective( @@ -1048,17 +990,16 @@ std::shared_ptr ProcessGroupNCCL::collective( std::vector& outputs, Fn fn, PreProcess pre, - PostProcess post, - OpType opType) { + PostProcess post) { const auto devices = getDeviceList(inputs); const auto key = getKeyFromDevices(devices); - auto& ncclComms = getNCCLComm(key, devices, opType); + auto& ncclComms = getNCCLComm(key, devices); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices, rank_, opType); + auto work = initWork(devices); // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. @@ -1102,13 +1043,11 @@ std::shared_ptr ProcessGroupNCCL::collective( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; (*work->cudaEvents_)[i].record(ncclStream); work->ncclComms_[i] = ncclComms[i]; + work->blockingWait_ = blockingWait_; + work->opTimeout_ = opTimeout_; + work->store_ = store_; } - // Set appropriate work parameters. - work->blockingWait_ = blockingWait_; - work->opTimeout_ = opTimeout_; - work->store_ = store_; - if (asyncErrorHandling_) { workEnqueue(work); } @@ -1121,21 +1060,21 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, - OpType opType, + NCCLCommType commType, PreProcess pre, PostProcess post) { const auto devices = getDeviceList(tensors); const auto key = getKeySendRecv(rank_, peer); int p2pRank = rank_ < peer ? 0 : 1; - auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank); + auto& ncclComms = getNCCLComm(key, devices, commType, p2pRank); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices, rank_, opType); + auto work = initWork(devices); - if (opType == OpType::RECV) { + if (commType == NCCLCommType::RECV) { // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. 
work->outputs_ = std::make_shared>(tensors); @@ -1191,15 +1130,13 @@ template std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, - Fn fn, - OpType opType) { + Fn fn) { return collective( inputs, outputs, fn, [](std::vector&) {}, - [](std::vector&) {}, - opType); + [](std::vector&) {}); } template @@ -1207,12 +1144,12 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, - OpType opType) { + NCCLCommType type) { return pointToPoint( tensor, fn, peer, - opType, + type, [](std::vector&) {}, [](std::vector&) {}); } @@ -1237,8 +1174,7 @@ std::shared_ptr ProcessGroupNCCL::allreduce( getNcclReduceOp(opts.reduceOp, input), comm, stream.stream()); - }, - OpType::ALLREDUCE); + }); } std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( @@ -1268,8 +1204,7 @@ std::shared_ptr ProcessGroupNCCL::broadcast( root, comm, stream.stream()); - }, - OpType::BROADCAST); + }); } std::shared_ptr ProcessGroupNCCL::reduce( @@ -1294,8 +1229,7 @@ std::shared_ptr ProcessGroupNCCL::reduce( root, comm, stream.stream()); - }, - OpType::REDUCE); + }); } std::shared_ptr ProcessGroupNCCL::allgather( @@ -1338,8 +1272,7 @@ std::shared_ptr ProcessGroupNCCL::allgather( outputTensors[i][j].copy_(outputFlattened[i][j], true); } } - }, - OpType::ALLGATHER); + }); } std::shared_ptr ProcessGroupNCCL::allgather_coalesced( @@ -1391,8 +1324,7 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( } } }, - [&](std::vector& ncclStreams) {}, - OpType::REDUCE_SCATTER); + [&](std::vector& ncclStreams) {}); } std::shared_ptr ProcessGroupNCCL::barrier( @@ -1462,8 +1394,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }, - OpType::ALLTOALL_BASE); + }); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1495,8 +1426,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }, - OpType::ALLTOALL_BASE); + }); } } @@ -1520,7 +1450,7 @@ std::shared_ptr ProcessGroupNCCL::send( stream.stream()); }, dstRank, - OpType::SEND); + NCCLCommType::SEND); return ret; } @@ -1529,7 +1459,7 @@ std::shared_ptr ProcessGroupNCCL::recv( int srcRank, int /* unused */) { check_gpu_tensors(tensors); - auto ret = pointToPoint( + auto ret= pointToPoint( tensors, [&](at::Tensor& output, ncclComm_t comm, @@ -1544,7 +1474,7 @@ std::shared_ptr ProcessGroupNCCL::recv( stream.stream()); }, srcRank, - OpType::RECV); + NCCLCommType::RECV); return ret; } #else diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 13b8c72b318a..b8b3d5aabd35 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -24,6 +23,13 @@ constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT"; // Handling with NCCL. constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; +// NCCL Commmunication type +enum class NCCLCommType : std::uint8_t { + SEND = 0, + RECV, + COLL, +}; + // ProcessGroupNCCL implements NCCL bindings for c10d. 
// // All functions of the class are expected to be called in the same order @@ -65,7 +71,7 @@ class ProcessGroupNCCL : public ProcessGroup { public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices - WorkNCCL(const std::vector& devices, int rank, OpType opType); + WorkNCCL(const std::vector& devices); // Copy constructor doing partial copy without outputs_. Cleanup thread // monitors and removes finished works. However it will deadlock when // destructs outputs_ tensors who are view tensors in autograd graph. @@ -141,10 +147,6 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) const; - friend std::ostream& operator<<( - std::ostream& output, - const WorkNCCL& workNCCL); - private: // Helper function for synchronize void synchronizeInternal(std::chrono::milliseconds timeout); @@ -164,7 +166,6 @@ class ProcessGroupNCCL : public ProcessGroup { // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; - // Store streams that run FutureNCCL then callbacks. std::vector> futureNCCLCallbackStreams_; @@ -466,7 +467,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector>& getNCCLComm( const std::string& devicesKey, const std::vector& devices, - OpType opType, + NCCLCommType commType = NCCLCommType::COLL, int p2pRank = 0); // Wrapper method which can be overridden for tests. @@ -474,9 +475,7 @@ class ProcessGroupNCCL : public ProcessGroup { const std::vector>& ncclComms); virtual std::shared_ptr initWork( - std::vector devices, - int rank, - OpType opType); + std::vector devices); private: // Helper that encapsulates work shared across all collective communication @@ -489,16 +488,14 @@ class ProcessGroupNCCL : public ProcessGroup { std::shared_ptr collective( std::vector& input, std::vector& output, - Fn fn, - OpType opType); + Fn fn); template std::shared_ptr collective( std::vector& input, std::vector& output, Fn fn, PreProcess pre, - PostProcess post, - OpType opType); + PostProcess post); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -508,13 +505,13 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& tensor, Fn fn, int peer, - OpType opType); + NCCLCommType commType); template std::shared_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, - OpType opType, + NCCLCommType commType, PreProcess pre, PostProcess post); @@ -540,8 +537,8 @@ class ProcessGroupNCCL : public ProcessGroup { // accordingly. void parseNcclBlockingWait(); - // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets - // asyncErrorHandling_ accordingly. + // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_ + // accordingly. 
void parseNcclAsyncErrorHandling(); void workCleanupLoop(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 0df197d17cbb..93f633938e18 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -16,10 +16,8 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLSimulateErrors( const std::vector& devices, - bool simulate_error, - int rank, - c10d::OpType opType) - : WorkNCCL(devices, rank, opType), simulate_error_(simulate_error) {} + bool simulate_error) + : WorkNCCL(devices), simulate_error_(simulate_error) {} std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) @@ -57,11 +55,8 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { } std::shared_ptr initWork( - std::vector devices, - int rank, - c10d::OpType opType) override { - return std::make_shared( - devices, simulate_error_, rank, opType); + std::vector devices) override { + return std::make_shared(devices, simulate_error_); } size_t getNCCLCommCacheSize() { @@ -84,11 +79,8 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLTimedoutErrors( const std::vector& devices, - bool set_timedout_error, - int rank, - c10d::OpType opType) - : WorkNCCL(devices, rank, opType), - set_timedout_error_(set_timedout_error) {} + bool set_timedout_error) + : WorkNCCL(devices), set_timedout_error_(set_timedout_error) {} private: bool isCompleted() override { @@ -113,11 +105,9 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { set_timedout_error_(false) {} std::shared_ptr initWork( - std::vector devices, - int rank, - c10d::OpType opType) override { + std::vector devices) override { return std::make_shared( - devices, set_timedout_error_, rank, opType); + devices, set_timedout_error_); } void set_timedout_error() { From 72e4f51bc07b464b32a3868abe5953013bf87ba9 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 7 Oct 2020 17:33:55 -0700 Subject: [PATCH 44/69] [JIT] fix dict update (#45857) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45857 Fix for https://github.com/pytorch/pytorch/issues/45627 Op was calling `insert` instead of `insert_or_assign`, so it wouldn't overwrite an existing key. 
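For illustration, here is the same `insert` vs `insert_or_assign` distinction on a plain `std::map` (a stand-in for `c10::Dict`, not the actual op code): `insert` is a no-op when the key already exists, which is why repeated `dict.update` calls kept the stale value.

```
// Standalone illustration of the semantics behind this fix; std::map is
// only a stand-in for c10::Dict.
#include <cassert>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> dict{{"a", 0}};

  // insert() leaves an existing key untouched, so an "update" built on it
  // silently keeps the old value.
  dict.insert({"a", 1});
  assert(dict.at("a") == 0);

  // insert_or_assign() overwrites the mapped value, matching the behavior
  // of Python's dict.update().
  dict.insert_or_assign("a", 1);
  assert(dict.at("a") == 1);
  return 0;
}
```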
Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24148805 Pulled By: eellison fbshipit-source-id: bf39c71d5d928890b82cff1a9a0985dc47c1ffac --- test/jit/test_list_dict.py | 9 +++++++++ torch/csrc/jit/runtime/register_prim_ops.cpp | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 19e4952cad57..29d2c0059395 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -1325,6 +1325,15 @@ def update(a, b): self.checkScript(update, (self.dict(), self.dict())) self.checkScript(update, (self.dict(), self.dict2())) + def test_update_existing_key(self): + def foo() -> Dict[str, int]: + a: Dict[str, int] = {} + for i in range(3): + a.update({'a': i}) + return a + + self.checkScript(foo, ()) + def test_aug_assign(self): def aug_assign_dict_tensor(a): # type: (Dict[str, Tensor]) -> Dict[str, Tensor] diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index d32dd998a040..f1cda66a52aa 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1177,7 +1177,7 @@ void dictUpdate(Stack* stack) { auto dict = pop(stack).toGenericDict(); for (const auto& item : to_add) { - dict.insert(item.key(), item.value()); + dict.insert_or_assign(item.key(), item.value()); } } From c86655a815922b22f089b6bda1ae108ffd75e637 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 7 Oct 2020 17:33:55 -0700 Subject: [PATCH 45/69] [JIT] Fix Dict bug in constant hashing (#45929) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45929 We were checking `and` when we should have been checking `or`. Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24148804 Pulled By: eellison fbshipit-source-id: 9c394ea10ac91a588169d934b1e3208512c71b9d --- test/cpp/jit/test_constant_pooling.cpp | 22 ++++++++++++++++++++++ test/test_jit.py | 16 ++++++---------- torch/csrc/jit/ir/node_hashing.cpp | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/test/cpp/jit/test_constant_pooling.cpp b/test/cpp/jit/test_constant_pooling.cpp index 8479c96742b0..6f81e5db907b 100644 --- a/test/cpp/jit/test_constant_pooling.cpp +++ b/test/cpp/jit/test_constant_pooling.cpp @@ -88,5 +88,27 @@ graph(): /*exactly*/ true) ->run(*graph); } + +TEST(ConstantPoolingTest, DictConstantPooling) { + auto graph = std::make_shared(); + parseIR( + R"IR( +graph(): + %0 : int = prim::Constant[value=1]() # test/elias.py:6:9 + %1 : int = prim::Constant[value=2]() # test/elias.py:6:12 + %a.1 : Dict(int, int) = prim::DictConstruct(%0, %1) + %b.1 : Dict(int, int) = prim::DictConstruct(%1, %1) + return (%a.1, %b.1) + )IR", + &*graph); + ConstantPropagation(graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count( + "Dict(int, int) = prim::Constant", + 2, + /*exactly*/ true) + ->run(*graph); +} } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 5baa240e30b8..797904d2bf20 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -76,7 +76,7 @@ from collections import defaultdict, namedtuple, OrderedDict import copy from copy import deepcopy -from itertools import product, chain +from itertools import product import itertools from textwrap import dedent from typing import List, Dict, Optional, Tuple, Union @@ -1969,15 +1969,17 @@ def check_constant(constant_constructor): tup_constant = constants[i] + ", " + constants[j] check_constant(tup_constant) 
+ dict_constants = [] for i in range(len(constants)): # check_constant constructs the second dict with another Tensor # which fails the comparison - if isinstance(eval(constants[i]), (list, bool, Tensor)) or eval(constants[i]) is None: + if not isinstance(eval(constants[i]), (str, int, float)): continue for j in range(len(constants)): dict_constant = "{ " + constants[i] + ": " + constants[j] + "}" check_constant(dict_constant) - + dict_constants.append(dict_constant) + constants = constants + dict_constants # testing node hashing funcs_template = dedent(''' @@ -2009,14 +2011,8 @@ def func(): # generate dicts with built-in types (excluding torch.Tensor) xprod = itertools.product(constants, constants) - def keys_pred(t): - return isinstance(eval(t[0]), (list, bool)) or eval(t[0]) is None - - filt = [x for x in xprod if not keys_pred(x)] - dict_strs = map(lambda t: '{' + t[0] + ':' + t[1] + '}', filt) - # test that equal tuples and dicts correctly work with node hashing - for tup in chain(map(lambda x: "(" + x + ",)", constants), dict_strs): + for tup in map(lambda x: "(" + x + ",)", constants): funcs_str = funcs_template.format(constant_constructor=tup) scope = {} execWrapper(funcs_str, globals(), scope) diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 52cace15075f..c690af12068d 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -114,7 +114,7 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) { const auto& e_a1 = *it_a1; const auto& e_a2 = *it_a2; - if (!ivaluesEqual(e_a1.key(), e_a2.key()) && + if (!ivaluesEqual(e_a1.key(), e_a2.key()) || !ivaluesEqual(e_a1.value(), e_a2.value())) { return false; } From a36f11a3a58e9fe026297fe2f0822f8fa706c037 Mon Sep 17 00:00:00 2001 From: Venkata Chintapalli Date: Wed, 7 Oct 2020 17:41:03 -0700 Subject: [PATCH 46/69] [FakeLowP] T76913842 Make AddFakeFp16 take int inputs (#45992) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45992 Created a template version of AddFakeFp16 to take both float and int inputs. 
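The idea, sketched standalone below, is that floating-point inputs keep going through fp16 emulation while integral inputs take a plain elementwise add. The `round_to_fp16` helper here is a crude hypothetical stand-in (mantissa truncation only), not the actual fakelowp rounding path.

```
// Sketch of the float/int dispatch only; round_to_fp16 is a placeholder,
// not the real fakelowp fp16 emulation.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Crude stand-in: truncate the fp32 mantissa to 10 bits (no rounding,
// denormal, or range handling).
inline float round_to_fp16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u; // keep sign, exponent, top 10 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

template <typename T>
std::vector<T> add_fake_fp16(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> c(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if constexpr (std::is_floating_point<T>::value) {
      // Float path: emulate fp16 precision on the inputs and the result.
      c[i] = round_to_fp16(round_to_fp16(a[i]) + round_to_fp16(b[i]));
    } else {
      // Integer path: exact add, no fp16 emulation needed.
      c[i] = a[i] + b[i];
    }
  }
  return c;
}

int main() {
  auto f = add_fake_fp16<float>({0.1f, 0.2f}, {0.3f, 0.4f});
  auto i = add_fake_fp16<int32_t>({1, 2}, {3, 4});
  assert(f.size() == 2);
  assert(i[1] == 6);
  return 0;
}
```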
Test Plan: notebook with local bento kernel: N369049 Reviewed By: amylittleyang Differential Revision: D24169720 fbshipit-source-id: 679de391224f65f6c5b3ca890eb0d157f09712f6 --- .../contrib/fakelowp/elementwise_fp16_fake_op.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc b/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc index 7debb5f7cf7e..c61668868178 100644 --- a/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc +++ b/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc @@ -22,7 +22,21 @@ int getSizeFromDims(const std::vector& dims) { template struct FP16PairWiseCPUFunctor : public OP { + template bool Forward( + const std::vector& A_dims, + const std::vector& B_dims, + const TIn* A, + const TIn* B, + TOut* C, + CPUContext* context) const { + OP::Forward(A_dims, B_dims, A, B, C, context); + + return true; + } + + template<> + bool Forward( const std::vector& A_dims, const std::vector& B_dims, const float* A, @@ -54,7 +68,7 @@ OPERATOR_SCHEMA(SumFakeFp16).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX); REGISTER_CPU_OPERATOR( AddFakeFp16, BinaryElementwiseOp< - TensorTypes, + TensorTypes, CPUContext, FP16PairWiseCPUFunctor>>); OPERATOR_SCHEMA(AddFakeFp16).NumInputs(2).NumOutputs(1); From 19da1d22fe337e19a98d4cbc711b888e1c90b0ee Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Wed, 7 Oct 2020 18:11:16 -0700 Subject: [PATCH 47/69] [NNC] Registerizer V2, supporting partial and conditional replacement (#45574) Summary: This is a rewrite of the Registerizer, supporting scalar replacement in *vastly* more situations. As a refresher, the registerizer does this: Before: ``` A[0] = 0; for (int x = 0; x < 10; x++) { A[0] = (A[0]) + x; } ``` After: ``` int A_ = 0; for (int x = 0; x < 10; x++) { A_ = x + A_; } A[0] = A_; ``` Which can greatly reduce the number of accesses to main memory in a kernel. There are cases where doing this gets complicated, and the existing implementation bails out whenever encountering multiple partial overlaps of the same buffer, or conditional accesses under any circumstances. This makes it much less useful in the presence of complex (ie. real world not example) kernels. This new version should work optimally in almost all cases (I have a few minor follow ups). I tested this version extensively, and found quite a few bugs in the original implementation I'd prefer not to back port fixes for - so I'm in favor of landing this even if we don't immediately see a perf win. I believe the killer app for this kind of optimization is fused reductions and we haven't enabled many examples of that yet. It is safe to move two accesses of the same Tensor element to a local scalar Var if between all usages of the element there are no other Loads or Stores that may refer to it. In the comments I refer to this as overlapping the access, or "cutting" the existing AccessInfo. In the case where a candidate for registerization is cut, it may be possible to finalize the access early by writing it back to the Tensor and then create a new scalar variable after the overlapping access is complete. We will attempt to do this when it saves memory accesses. There are a few cases that make this more challenging: - For: Loops change the number of real usages of a buffer by the loop extent, but only if we can pull the definition and finalization of the scalar variable out of the loop block. For loops often create accesses which are conditional on a loop var and will overlap large ranges of elements. E.g. 
Before: ``` A[0] = 2; for (int x1 = 0; x1 < 10; x1++) { A[0] = (A[0]) + x1; } for (int x2 = 1; x2 < 10; x2++) { A[x2] = A[x2 - 1]; } for (int x3 = 0; x3 < 10; x3++) { A[0] = (A[0]) + x3; } ``` After: ``` int A_1 = 2; for (int x1 = 0; x1 < 10; x1++) { A_1 = A_1 + x1; } A[0] = A_1; for (int x2 = 1; x2 < 10; x2++) { A[x2] = A[x2 - 1]; } int A_2 = A[0]; for (int x3 = 0; x3 < 10; x3++) { A_2 = A_2 + x3; } A[0] = A_2; ``` - Cond: Conditions complicate lifting scalars out of internal scopes. Generally we cannot lift an access outside of a conditional scope unless there is already a reference to that same access at the higher scope, since we don't know if the condition was guarding an array access not safe at the higher scope. In the comments I refer to this as the condition "hiding" the access, and the outer access "unhiding" it. E.g. this example: ``` if (x<5 ? 1 : 0) { A[x] = (A[x]) + 1; } A[x] = (A[x]) + 1; if (x>5 ? 1 : 0) { A[x] = (A[x]) + 1; } ``` The A[x] access can be registerized due to the unconditional access between the two conditions: ``` int A_1 = A[x]; if (x<5 ? 1 : 0) { A_1 = A_1 + 1; } A_1 = A_1 + 1; if (x>5 ? 1 : 0) { A_1 = A_1 + 1; } A[x] = A_1; ``` But this example has no accesses that can be registerized: ``` if (x<5 ? 1 : 0) { A[x] = (A[x]) + 1; } if (x>5 ? 1 : 0) { A[x] = (A[x]) + 1; } ``` - IfThenElse: Same situation as Cond, except since IfThenElse is an Expr rather than a Stmt we cannot insert the scalar definition or finalizer within the conditional scope. Accesses inside an IfThenElse can be safely combined with external accesses but cannot exist completely within. E.g in this example the `B[x]` cannot be registerized as there is no safe place to define it. ``` A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); ``` But the equivalent kernel using Cond can be registerized: ``` if (x<3 ? 1 : 0) { float B_1 = B[x]; A[x] = B_1 + B_1; } else { A[x] = B[x]; } ``` - Let: Accesses dependent on local variables via Let Stmts, or loop vars, cannot be raised outside of the scope of the dependent var. E.g. no accesses in this example can be registerized: ``` for (int x = 0; x < 10; x++) { int y = 30; A[y] = x + (A[y]); } ``` But they can in this example: ``` int y = 30; for (int x = 0; x < 10; x++) { A[y] = x + (A[y]); } ``` **Testing** The majority of this PR is tests, over 3k lines of them, because there are many different rules to consider and they can interact together more or less arbitrarily. I'd greatly appreciate any ideas for situations we could encounter that are not covered by the tests. **Performance** Still working on it, will update. In many FastRRNS sub kernels this diff reduces the number of total calls to Store or Load by 4x, but since those kernels use Concat very heavily (meaning a lot of branches) the actual number encountered by any particular thread on GPU is reduced only slightly. Overall perf improved by a very small amount. Reductions is where this optimization should really shine, and in particular the more complex the kernel gets (with extra fusions, etc) the better this version of the registerizer should do compared the existing version. 
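For reference, a minimal sketch of invoking the pass on a hand-built statement, following the pattern in the updated tests (the include paths and the `torch::jit::tensorexpr` namespace are assumed here; see test_registerizer.cpp below for the full set of cases):

```
// Minimal usage sketch mirroring testRegisterizerSimple; include paths are
// assumed and may differ.
#include <iostream>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/registerizer.h>

using namespace torch::jit::tensorexpr;

void registerizeExample() {
  KernelScope kernel_scope;
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // A[0] = 0; for (int x = 0; x < 10; x++) { A[0] = (A[0]) + x; }
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x, 0, 10,
           Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1))});

  // After registerization the repeated A[0] accesses become a scalar:
  //   int A_1 = 0; for (...) { A_1 = x + A_1; } A[0] = A_1;
  stmt = registerize(stmt);
  std::cout << *stmt;
}
```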
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45574 Reviewed By: albanD Differential Revision: D24151517 Pulled By: nickgg fbshipit-source-id: 9f0b2d98cc213eeea3fda16fee3d144d49fd79ae --- test/cpp/tensorexpr/test_cuda.cpp | 6 +- test/cpp/tensorexpr/test_registerizer.cpp | 3346 ++++++++++++++++++- test/cpp/tensorexpr/tests.h | 632 ++-- torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 7 +- torch/csrc/jit/tensorexpr/registerizer.cpp | 883 +++-- torch/csrc/jit/tensorexpr/registerizer.h | 371 +- torch/csrc/jit/tensorexpr/stmt.h | 8 + 7 files changed, 4506 insertions(+), 747 deletions(-) diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 6dba8c574c57..df5359da83a9 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -769,14 +769,14 @@ void testCudaSharedMemReduce_1() { // Check the c write is not masked, but the d write is. const std::string& verification_pattern = R"IR( -# CHECK: c_ = 0 +# CHECK: c_1 = 0 # CHECK: for (int m = 0; m < 128 -# CHECK: c_ = c_ + +# CHECK: c_1 = c_1 + # CHECK: __syncthreads(); # CHECK: if (threadIdx.x<1 # CHECK: b[blockIdx.x] = # CHECK: __syncthreads(); -# CHECK: atomicAdd(&b[blockIdx.x], c_) +# CHECK: atomicAdd(&b[blockIdx.x], c_1) )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index b286ab7b8151..91360e9ff8d3 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -31,14 +31,14 @@ void testRegisterizerSimple() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -46,11 +46,11 @@ void testRegisterizerSimple() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -77,7 +77,7 @@ void testRegisterizerLoop() { */ // No change. - registerize(stmt); + stmt = registerize(stmt); /* * A[0] = 0; @@ -96,7 +96,7 @@ void testRegisterizerLoop() { # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: A[x] = -# CHECK-NOT: A[0] = A_;)IR"; +# CHECK-NOT: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -119,17 +119,17 @@ void testRegisterizerLoopFixedLoad() { /* * A[0] = 0; * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; + * A[x] = (A[0]) + x; * } */ // No change. - registerize(stmt); + stmt = registerize(stmt); /* * A[0] = 0; * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; + * A[x] = (A[0]) + x; * } */ @@ -143,11 +143,258 @@ void testRegisterizerLoopFixedLoad() { # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: A[x] = -# CHECK-NOT: A[0] = A_;)IR"; +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize accesses that occur entirely within inner scopes, even if +// they depend on the loop var. 
+void testRegisterizerLoopInternal() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * A[x] = (A[x]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * A_1 = A_1 + x; + * A_1 = A_1 + x; + * A[x] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = A_1 + x; +# CHECK: A_1 = A_1 + x; +# CHECK: A[x] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access can be overlapped by another read in the same Expr. In this case +// B[z] and B[y] overlap and prevent registerization of both accesses. +void testRegisterizerLoopInternalLoadOverlap() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Store::make( + a, + {x}, + Add::make(Load::make(b, {y}, 1), Load::make(b, {z}, 1)), + 1))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (B[y]) + (B[z]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +void testRegisterizerLoopInternalRepeated() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[1]; + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * A_2 = x + A_1; + * } + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * A_2 = x + A_1; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[1]; +# CHECK: int A_2 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = x + A_1; +# CHECK: A_2 = x + A_1; +# CHECK: } +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = x + A_1; +# CHECK: A_2 = x + A_1; +# CHECK: } +# CHECK-NOT: A[1] +# CHECK: A[0] = A_2; +# CHECK-NOT: A[1] +# CHECK: })IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } +void testRegisterizerLoopInternalRepeatedOverlapLoopVar() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)})), + For::make( + x, + 0, + 10, + 
Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +void testRegisterizerLoopInternalRepeatedOverlapOther() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1), + Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1), + Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + // Will registerize multiple accesses of different items of the same buffer. void testRegisterizerMultiVar() { KernelScope kernel_scope; @@ -174,17 +421,17 @@ void testRegisterizerMultiVar() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; - * int A__1 = 0; + * int A_1 = 0; + * int A_2 = 0; * for (int x = 0; x < 10; x++) { - * A__1 = x + A__1; - * A_ = A_ - x; + * A_2 = x + A_2; + * A_1 = A_1 - x; * } - * A[1] = A__1; - * A[0] = A_; + * A[1] = A_2; + * A[0] = A_1; */ std::ostringstream oss; @@ -192,14 +439,14 @@ void testRegisterizerMultiVar() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; -# CHECK: int A__1 = 0; +# CHECK: int A_1 = 0; +# CHECK: int A_2 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A__1 = -# CHECK: A[1] = A__1 -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A_2 = +# CHECK: A[1] = A_2 +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -234,17 +481,17 @@ void testRegisterizerVariableLoad() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { * B[x] = x; * } * for (int x_1 = 0; x_1 < 10; x_1++) { - * A_ = A_ + (B[x_1]); + * A_1 = A_1 + (B[x_1]); * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -252,13 +499,13 @@ void testRegisterizerVariableLoad() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK: B[x] = x # CHECK: for (int x_1 = 0; x_1 < 10; x_1++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -286,14 +533,14 @@ void testRegisterizerSymbolicIndices() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[i] = A_; + * A[i] = A_1; */ std::ostringstream oss; @@ -301,46 +548,15 @@ 
void testRegisterizerSymbolicIndices() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[i] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[i] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Will not registerize if a variable usage of the sclar may overlap the target -// scalar. -// TODO: we can support this by writing back to the buffer before the variable -// access, but we'd need temporal analysis of dependencies which we don't have -// yet. Will have to fix soon though. -void testRegisterizerEarlyStop() { - KernelScope kernel_scope; - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - Stmt* stmt = Block::make( - {Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})), - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))}); - - std::ostringstream before; - before << *stmt; - - // No change. - registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - // Can registerize accesses dependent on multiple loop vars. void testRegisterizerMultiLoop() { KernelScope kernel_scope; @@ -372,16 +588,16 @@ void testRegisterizerMultiLoop() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { * for (int y = 0; y < 10; y++) { - * A_ = x * y + y * A_l + * A_1 = x * y + y * A_1; * } * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -389,12 +605,12 @@ void testRegisterizerMultiLoop() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK: for (int y = 0; y < 10; y++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -418,23 +634,24 @@ void testRegisterizerRepeated() { // Registerize manually to make sure we only replace a single target. { - RegisterizerAnalysis analysis; + registerizer::RegisterizerAnalysis analysis; stmt->accept(&analysis); auto candidates = analysis.getCandidates(); ASSERT_EQ(candidates.size(), 2); - RegisterizerReplacer replacer(candidates.front()); + candidates.pop_back(); + registerizer::RegisterizerReplacer replacer(candidates); stmt = stmt->accept_mutator(&replacer); } // Re-analyze and replace the second target. { - RegisterizerAnalysis analysis; + registerizer::RegisterizerAnalysis analysis; stmt->accept(&analysis); auto candidates = analysis.getCandidates(); ASSERT_EQ(candidates.size(), 1); - RegisterizerReplacer replacer(candidates.front()); + registerizer::RegisterizerReplacer replacer(candidates); stmt = stmt->accept_mutator(&replacer); } @@ -443,19 +660,19 @@ void testRegisterizerRepeated() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; -# CHECK: int A__1 = 0; +# CHECK: int A_1 = 0; +# CHECK: int A_1_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A__1 = -# CHECK: A[1] = A__1 -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A_1_1 = +# CHECK: A[1] = A_1_1; +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Can registerize rthe load of A. +// Can registerize the load of A. 
void testRegisterizerNoLoads() { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); @@ -472,14 +689,14 @@ void testRegisterizerNoLoads() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + 1; + * A_1 = x + 1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -487,11 +704,11 @@ void testRegisterizerNoLoads() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -518,17 +735,17 @@ void testRegisterizerNoRepeatedStores() { * } */ - registerize(stmt); + stmt = registerize(stmt); // TODO: its unnecessary to reorder the initializer of A[0], but it's not // actually worse so lets not worry for now. /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * B[x] = x + A_; + * B[x] = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -536,11 +753,11 @@ void testRegisterizerNoRepeatedStores() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: B[x] = -# CHECK: A[0] = A_;)IR"; +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -567,7 +784,7 @@ void testRegisterizerMultiVarOverlap() { before << *stmt; // No change. - registerize(stmt); + stmt = registerize(stmt); std::ostringstream after; after << *stmt; @@ -609,19 +826,19 @@ void testRegisterizerAllocs() { * Free(B); */ - registerize(stmt); + stmt = registerize(stmt); /* - * int C_ = C[0]; + * int C_1 = C[0]; * Allocate(B, int, {C_}); - * int A_ = C_; - * int B_ = 0; + * int A_1 = C_1; + * int B_1 = 0; * for (int x = 0; x < 10; x++) { - * B_ = B_ + x; - * A_ = C_; + * B_1 = B_1 + x; + * A_1 = C_1; * } - * B[0] = B_; - * A[0] = A_; + * B[0] = B_1; + * A[0] = A_1; * Free(B); */ @@ -630,15 +847,15 @@ void testRegisterizerAllocs() { const std::string& verification_pattern = R"IR( -# CHECK: int C_ = C[0]; +# CHECK: int C_1 = C[0]; # CHECK: Allocate(B -# CHECK: int A_ = C_; -# CHECK: int B_ = 0; +# CHECK: int A_1 = C_1; +# CHECK: int B_1 = 0; # CHECK: for (int x = 0; x < 10; x++) -# CHECK: B_ = -# CHECK: A_ = C_ -# CHECK: B[0] = B_; -# CHECK: A[0] = A_; +# CHECK: B_1 = +# CHECK: A_1 = C_ +# CHECK: B[0] = B_1; +# CHECK: A[0] = A_1; # CHECK: Free(B)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -661,14 +878,14 @@ void testRegisterizerNoInitializer() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = A[0]; + * int A_1 = A[0]; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -676,15 +893,44 @@ void testRegisterizerNoInitializer() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = A[0]; +# CHECK: int A_1 = A[0]; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } +void testRegisterizerNoInitializerLoopVar() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, 
Add::make(Load::make(a, {x}, 1), x), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + void testRegisterizerLoadThenStore() { KernelScope kernel_scope; BufHandle a("A", {1}, kInt); @@ -704,17 +950,17 @@ void testRegisterizerLoadThenStore() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = A[0]; - * int B_ = B[0]; + * int A_1 = A[0]; + * int B_1 = B[0]; * for (int x = 0; x < 10; x++) { - * B_ = x + A_; - * A_ = B_; + * B_1 = x + A_1; + * A_1 = B_1; * } - * B[0] = B_; - * A[0] = A_; + * B[0] = B_1; + * A[0] = A_1; */ std::ostringstream oss; @@ -722,15 +968,15 @@ void testRegisterizerLoadThenStore() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = A[0]; -# CHECK: int B_ = B[0]; +# CHECK: int A_1 = A[0]; +# CHECK: int B_1 = B[0]; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: B[ -# CHECK: B_ = +# CHECK: B_1 = # CHECK-NOT: A[ -# CHECK: A_ = B_ +# CHECK: A_1 = B_ # CHECK: B[0] = B_ -# CHECK: A[0] = A_;)IR"; +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } @@ -763,55 +1009,2805 @@ void testRegisterizerParallelized() { "Registerization must occur after parallelism flattening"); } -void testRegisterizerConditions() { +// Should be able to registerize this since the scalar would exist before the +// branch. +void testRegisterizerConditionAfter() { KernelScope kernel_scope; BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({ - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}, 1), x), - Add::make(Load::make(a, {x - 5}, 1), x)), - 1), - Store::make( - a, - {x - 5}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}, 1), x), - Add::make(Load::make(a, {x - 5}, 1), x)), - 1)), - }))}); - std::ostringstream before; - before << *stmt; + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); - /* for (int x = 0; x < 10; x++) { - * if (x<5 ? 1 : 0) { - * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } else { - * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; * } */ - // No change. - registerize(stmt); - - std::ostringstream after; - after << *stmt; + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this since the scalar exists in the same form +// after the branch and there is no overlap. 
+void testRegisterizerConditionBefore() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = B[x]; + * C[x] = A[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_ 1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = B[x]; + * C[x] = A_1; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this as the combination of the two above rules. +void testRegisterizerConditionInside() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * B[x] = A_1; + * A_1 = C[x]; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: B[x] = A_1; +# CHECK: A_1 = C[x]; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An example where an access is cut by an overlapping access inside a +// condition, and both sides are large enough to be registerized but cannot be +// because there is no safe place to put the initializer or finalizer. +void testRegisterizerConditionInsideOverlap1() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {0}, 3, 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 
1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + // The A[0] store overlaps, A[x] cutting the region that can be registerized + // into two groups. + // Each group has 2 loads and 2 stores however, so we could registerize it, + // but the first group would need to be finalized inside the condition block, + // the second would need to be initialized inside the condition block. There's + // no safe place to put these that's visible to the other uses in the group + // and so neither registerization is possible. + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Same as the above, but the access group before the condition (and after the +// condition) are large enough to be registerized without needing the access +// from the loop. Registerization occurs but does not include any accesses in +// the condition, and the first group must be finalized before the Cond, the +// second initialized after it. +void testRegisterizerConditionInsideOverlap2() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(a, {x}, Load::make(b, {x + 1}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {0}, 3, 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(b, {x + 1}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * A[x] = B[x + 1]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * B[x + 1] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; // A_1 initializer + * A_1 = B[x + 1]; // + * C[x] = A_1; // + * A[x] = A_1; // A_1 finalizer + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * int A_2 = A[x]; // A_2 initialier + * B[x] = A_2; // + * B[x + 1] = A_2; // + * A_2 = C[x]; // + * A[x] = A_2; // A_2 finalizer + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: A_1 = B[x + 1]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1; +# CHECK: if ( +# CHECK-NOT: A_1 = A_1 + 1; +# CHECK: A[x] = (A[x] +# CHECK: A[0] = +# CHECK: A[x] = (A[x] +# CHECK: } +# CHECK: int A_2 = A[x]; +# CHECK: B[x] = A_2; +# CHECK: B[x + 1] = A_2; +# CHECK: A_2 = C[x]; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// When accesses are within conditional blocks they are not visible to the wider +// program, because we don't know if the branch would be taken and if it isn't +// the accesses in it don't need to be valid (think size checks on the index). +// In this case the accesses cannot be registerized. 
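+// (Below, each A[x] access sits under either x<5 or x>5, so neither access is
+// known to execute and nothing can be registerized.)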
+void testRegisterizerConditionHidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But... if the same access is found in a non conditional scope, that means +// that that access is valid in the higher scope (or at least if its not it's +// the user's fault). It "unhides" the conditional accesses, allowing +// registerization to occur. +void testRegisterizerConditionUnhidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = (A[x]) + 1; <-- this is doing the unhiding. + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = A_1 + 1; + * if (x>5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (x<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x>5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of a Cond. +void testRegisterizerCondCondition() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if ((A[x])<5 ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * int C_1 = A_1; + * if (A_1<5 ? 
1 : 0) { + * C_1 = C_1 + 1; + * } + * C[x] = C_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: int C_1 = A_1; +# CHECK: if (A_1<5 +# CHECK: C_1 = C_1 + 1; +# CHECK: C[x] = C_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +void testRegisterizerCondConditionUnhidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 10), 1))}); + + /* + * if ((A[x])<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } else { + * A[x] = (A[x]) + 10; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (A_1<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } else { + * A_1 = A_1 + 10; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (A_1<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } else { +# CHECK: A_1 = A_1 + 10; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Conditional hiding also works for IfThenElse exprs. +void testRegisterizerIfThenElseHidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1)}); + + /* + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Conditional unhiding also works for IfThenElse exprs. +void testRegisterizerIfThenElseUnhidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make({ + Store::make(a, {x}, 0, 1), + Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + }); + + /* + * A[x] = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * B[y] = IfThenElse(x<5 ? 
1 : 0, A_1 + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested IfThenElse exprs can't promote to higher level scopes. +void testRegisterizerIfThenElseNested() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + BufHandle d("D", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + IfThenElse::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Load::make(d, {x}, 1), + Load::make(b, {x}, 1)), + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kEQ), + Load::make(c, {x}, 1), + Load::make(d, {x}, 1))), + 1)}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, + * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), + * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Cannot registerize an access completely contained within an IfThenElse +// branch, since it is not a Stmt and cannot hold variable definitions. We need +// to check that we don't promote the initializer/finalizer to the enclosing +// Block. +void testRegisterizerIfThenElseInternal() { + KernelScope kernel_scope; + // Making these floats so they don't get simplified to a single access. + BufHandle a("A", {5}, kFloat); + BufHandle b("B", {5}, kFloat); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)), + Load::make(b, {x}, 1)), + 1)}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // If this was a Cond instead of an IfThenElse then we could registerize the + // two accesses to B[x] in the True branch. + + // Actually lets verify that. + + stmt = Block::make({Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Store::make( + a, {x}, Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)), 1), + Store::make(a, {x}, Load::make(b, {x}, 1), 1))}); + + /* + * if (x<3 ? 1 : 0) { + * A[x] = (B[x]) + (B[x]); + * } else { + * A[x] = B[x]; + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<3 ? 
1 : 0) { + * float B_1 = B[x]; + * A[x] = B_1 + B_1; + * } else { + * A[x] = B[x]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK-NOT: float +# CHECK: if (x<3 +# CHECK: float B_1 = +# CHECK: A[x] = B_1 + B_1 +# CHECK: } else { +# CHECK: A[x] = B[x] +# CHECK: } +# CHECK-NOT: A[x] +# CHECK-NOT: B[x])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of an IfThenElse; +void testRegisterizerIfThenElseCondition() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(a, {x}, 1), 1), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Load::make(b, {0}, 1), + Load::make(c, {0}, 1)), + 1)}); + + /* + * A[x] = A[x]; <---- just here so there are enough accesses to combine. + * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * A_1 = A_1; + * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +void testRegisterizerIfThenElseConditionUnhidden() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + b, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x}, 1), 10)), + 1)}); + + /* + * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot promote accesses internal to IfThenElse branches even if the enclosing +// scope if conditional. +void testRegisterizerConditionBranchOnly() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({ + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), x), + Add::make(Load::make(a, {x - 5}, 1), x)), + 1), + Store::make( + a, + {x - 5}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), x), + Add::make(Load::make(a, {x - 5}, 1), x)), + 1)), + }))}); + + std::ostringstream before; + before << *stmt; + + /* for (int x = 0; x < 10; x++) { + * if (x<5 ? 1 : 0) { + * A[x] = IfThenElse(x<5 ? 
1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } else { + * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } + * } + */ + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// We can registerize an IfThenElse that appears in the condition branch of a +// Cond. This is a weird but valid thing to do. +void testRegisterizerCondIfThenElse() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make( + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(b, {x}, 1)), + x, + CompareSelectOperation::kEQ), + Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1), + nullptr)}); + + /* + * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + // access to A can be registerized, but not B or C + + /* + * int A_1 = A[x]; + * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] +# CHECK: C[x] = (C[x]) + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a conditional access in the RHS of a store unhidden by it's +// LHS, and hoist it out of a loop. +void testRegisterizerIfThenElseLoop() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(b, {y}, 1)), + 1)); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: for ( +# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot registerize if the RHS overlaps the access creating visibility. +void testRegisterizerIfThenElseLoopCut() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make({For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(a, {y}, 1)), + 1))}); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Simple case where an access is cut by an overlapping access later in the +// program, we can registerize up until the overlap. 
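+// (The second loop reads A[x - 1], which may alias A[0], so the scalar is
+// written back to A[0] just before that loop and is not used after it.)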
+void testRegisterizerPartialAfter() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})), + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK-NOT: A)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize an access which overlaps a previous access, the +// initializer must be inserted after the previous access. +void testRegisterizerPartialBefore() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)), + Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The combination of the previous two tests, an access is cut by an overlapping +// access in both directions. 
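+// (The middle loop writes A[x2], which may alias A[0]; the first scalar is
+// flushed before that loop and a second scalar is re-read from A[0] after it.)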
+void testRegisterizerPartialInside() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x1("x1", kInt); + VarHandle x2("x2", kInt); + VarHandle x3("x3", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 2, 1), + For::make( + x1, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x1), 1)), + For::make( + x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}, 1), 1)), + For::make( + x3, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x3), 1))}); + + /* + * A[0] = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A[0] = (A[0]) + x1; + * } + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * for (int x3 = 0; x3 < 10; x3++) { + * A[0] = (A[0]) + x3; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A_1 = A_1 + x1; + * } + * A[0] = A_1; + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * int A_2 = A[0]; + * for (int x3 = 0; x3 < 10; x3++) { + * A_2 = A_2 + x3; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x2] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x3; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An element could be registerized program wide but is cut by a conditional +// access, we should break this into two scalars and write back to the buffer +// before the condition. +void testRegisterizerPartialCondition() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 2, 1), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1), + nullptr), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1))}); + + /* + * A[0] = 2; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: A[x] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Tests case where an access is cut by an internal conditional access which +// itself is registerized. 
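+// (The conditional A[x] stores may alias A[0], closing the outer scalar at the
+// Cond, but they still registerize into a scalar local to the branch.)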
+void testRegisterizerPartialConditionInternalCut() {
+  KernelScope kernel_scope;
+  BufHandle a("A", {1}, kInt);
+  VarHandle x("x", kInt);
+  Stmt* stmt = Block::make(
+      {Store::make(a, {0}, 1, 1),
+       Store::make(a, {0}, 3, 1),
+       Cond::make(
+           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
+           Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}),
+           nullptr),
+       Store::make(a, {0}, 4, 1),
+       Store::make(a, {0}, 6, 1)});
+
+  /*
+   * A[0] = 1;
+   * A[0] = 3;
+   * if (x<5 ? 1 : 0) {
+   *   A[x] = 1;
+   *   A[x] = 3;
+   * }
+   * A[0] = 4;
+   * A[0] = 6;
+   */
+
+  stmt = registerize(stmt);
+
+  /*
+   * int A_1 = 1;
+   * A_1 = 3;
+   * A[0] = A_1;
+   * if (x<5 ? 1 : 0) {
+   *   int A_2 = 1;
+   *   A_2 = 3;
+   *   A[x] = A_2;
+   * }
+   * int A_3 = 4;
+   * A_3 = 6;
+   * A[0] = A_3;
+   */
+
+  std::ostringstream oss;
+  oss << *stmt;
+
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: int A_1 = 1;
+# CHECK: A_1 = 3
+# CHECK: A[0] = A_1;
+# CHECK: if (
+# CHECK: int A_2 = 1;
+# CHECK: A_2 = 3;
+# CHECK: A[x] = A_2;
+# CHECK: }
+# CHECK: int A_3 = 4;
+# CHECK: A_3 = 6;
+# CHECK: A[0] = A_3;)IR";
+
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+}
+
+// First statement in the condition closes the outer access, but can be
+// registerized with later statements.
+void testRegisterizerPartialConditionInternalStart() {
+  KernelScope kernel_scope;
+  BufHandle a("A", {1}, kInt);
+  VarHandle x("x", kInt);
+  Stmt* stmt = Block::make(
+      {Store::make(a, {0}, 1, 1),
+       Store::make(a, {0}, 3, 1),
+       Cond::make(
+           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
+           Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}),
+           nullptr),
+       Store::make(a, {x}, 4, 1),
+       Store::make(a, {x}, 6, 1)});
+
+  /*
+   * A[0] = 1;
+   * A[0] = 3;
+   * if (x<5 ? 1 : 0) {
+   *   A[x] = 1;
+   *   A[x] = 3;
+   * }
+   * A[x] = 4;
+   * A[x] = 6;
+   */
+
+  stmt = registerize(stmt);
+
+  /*
+   * int A_1 = 1;
+   * A_1 = 3;
+   * A[0] = A_1;
+   * int A_2 = A[x]; <--- must read from the input here.
+   * if (x<5 ? 1 : 0) {
+   *   A_2 = 1;
+   *   A_2 = 3;
+   * }
+   * A_2 = 4;
+   * A_2 = 6;
+   * A[x] = A_2;
+   */
+
+  // TODO: I suppose we could refactor with a conditional initializer?
+
+  std::ostringstream oss;
+  oss << *stmt;
+
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: int A_1 = 1;
+# CHECK: A_1 = 3
+# CHECK: A[0] = A_1;
+# CHECK: int A_2 = A[x];
+# CHECK: if (
+# CHECK: A_2 = 1;
+# CHECK: A_2 = 3;
+# CHECK: }
+# CHECK: A_2 = 4;
+# CHECK: A_2 = 6;
+# CHECK: A[x] = A_2;)IR";
+
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+}
+
+// An access cuts two open overlaps and creates four scalar variables.
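+// (The loop writing A[x] may touch both A[0] and A[1], so each element gets
+// one scalar before the loop and a fresh one after it, four locals in total.)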
+void testRegisterizerPartialOverlapsTwo() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({Store::make(a, {1}, Load::make(a, {0}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + For::make(x, 1, 10, Store::make(a, {x}, x, 1)), + Store::make(a, {1}, Load::make(a, {0}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1)}); + + /* + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int A_2 = A_1; + * A_1 = A_2; + * A_1 = A_2; + * A[1] = A_2; + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * int A_3 = A[0]; + * int A_4 = A_3; + * A_3 = A_4; + * A_3 = A_4; + * A[1] = A_4; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int A_2 = A_1; +# CHECK: A_1 = A_2; +# CHECK: A_1 = A_2; +# CHECK: A[1] = A_2; +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = x; +# CHECK: } +# CHECK: int A_3 = A[0]; +# CHECK: int A_4 = A_3; +# CHECK: A_3 = A_4; +# CHECK: A_3 = A_4; +# CHECK: A[1] = A_4; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested blocks will automatically be flattened and do not provent +// registerization of enclosed accesses. +void testRegisterizerNestedBlocks() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 2), 1)}), + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 3), 1), + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 4), 1)})})}); + + /* + * A[0] = (A[0]) + 1; + * { + * A[0] = (A[0]) + 2; + * } + * { + * A[0] = (A[0]) + 3; + * { + * A[0] = (A[0]) + 4; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * A_1 = A_1 + 2; + * A_1 = A_1 + 3; + * A_1 = A_1 + 4; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_1 = A_1 + 2; +# CHECK: A_1 = A_1 + 3; +# CHECK: A_1 = A_1 + 4; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The access can be registerized internally to a condition, but must ensure +// that both initializer and finalizer are within the same condition. +void testRegisterizerNestedConditions() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x==2 ? 
1 : 0) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If an access exists outside the scope of the condition then we can lift +// nested conditional usages into the same scalar. +void testRegisterizerNestedConditionsUnhidden() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {1}, 1, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x<5 +# CHECK: A[1] = 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +void testRegisterizerNestedConditionsHiddenFirst() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + stmt = registerize(stmt); +} + +void testRegisterizerNestedConditionsHiddenSecond() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. 
+ stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + stmt = registerize(stmt); +} + +// If an access is cut by another access internal to a condition block, it still +// cuts the access. +void testRegisterizerNestedConditionsCut() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {x}, 1, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +void testRegisterizerNestedConditionLoopHidden() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}))}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Three loops and four element regions, three of which should be registerized +// at different levels of the IR. +void testRegisterizerNestedConditionThreeDeep() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {4}, 0, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kGT), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kGT), + Block::make({ + Cond::make( + CompareSelect::make(x, 4, CompareSelectOperation::kGT), + Block::make({ + Store::make( + a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1), + Store::make( + a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1), + Store::make( + a, {3}, Add::make(Load::make(a, {3}, 1), 1), 1), + Store::make( + a, {4}, Add::make(Load::make(a, {4}, 1), 1), 1), + Store::make( + a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1), + }), + nullptr), + Store::make(a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1), + }), + nullptr), + nullptr)}); + + /* + * A[4] = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * if (x>4 ? 1 : 0) { + * A[1] = (A[1]) + 1; + * A[2] = (A[2]) + 1; + * A[3] = (A[3]) + 1; + * A[4] = (A[4]) + 1; + * A[1] = (A[1]) + 1; + * } + * A[2] = (A[2]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * int A_3 = A[2]; + * if (x>4 ? 
1 : 0) { + * int A_2 = A[1]; + * A_2 = A_2 + 1; + * A_3 = A_3 + 1; + * A[3] = (A[3]) + 1; + * A_1 = A_1 + 1; + * A_2 = A_2 + 1; + * A[1] = A_2; + * } + * A_3 = A_3 + 1; + * A[2] = A_3; + * } + * } + * A[4] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: if (x>2 ? 1 : 0) { +# CHECK: if (x>3 ? 1 : 0) { +# CHECK: int A_3 = A[2]; +# CHECK: if (x>4 ? 1 : 0) { +# CHECK: int A_2 = A[1]; +# CHECK: A_2 = A_2 + 1; +# CHECK: A_3 = A_3 + 1; +# CHECK: A[3] = (A[3]) + 1; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_2 = A_2 + 1; +# CHECK: A[1] = A_2; +# CHECK: } +# CHECK: A_3 = A_3 + 1; +# CHECK: A[2] = A_3; +# CHECK: } +# CHECK: } +# CHECK: A[4] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can replace a simple scalar access with a local variable even when that +// variable is an outer loop var. +void testRegisterizerNestedLoopSimple() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({For::make( + y, + 0, + 10, + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {y}, Add::make(Load::make(a, {y}, 1), x), 1)})))}); + + /* + * for (int y = 0; y < 10; y++) { + * for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * for (int y = 0; y < 10; y++) { + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[y] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int y +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = x + A_1; +# CHECK: } +# CHECK: A[y] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the positive case of the hiddenAccess split, where an internal +// conditional access can be hoisted up through a loop to match an existing +// access in a higher scope and the two can be registerized. +void testRegisterizerHiddenAccessYes() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, + {0}, + Add::make(Load::make(a, {0}, 1), 1), + 1)), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 
1 : 0) { + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the negative case of the hiddenAccess split, where the hoisted access is +// never unhidden at a higher scope and registerization occurs at the lower +// scope. +void testRegisterizerHiddenAccessNo() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * int A_1 = A[0]; + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: int A_1 = A[0]; +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: } +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// In this case the conditional access must be hoisted by two loops, there are +// two accesses here one is unhidden and the other isnt. A[0] can be +// registerized but B[0] cannot. +void testRegisterizerHiddenAccessMultiLoop() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Cond::make( + CompareSelect::make(y, 3, CompareSelectOperation::kEQ), + Block::make( + {Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Store::make( + b, + {0}, + Add::make(Load::make(b, {0}, 1), 1), + 1)}), + nullptr)})))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 
1 : 0) { + * A_1 = A_1 + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: for (int y +# CHECK: if (y==3 +# CHECK: A_1 = A_1 + 1; +# CHECK: B[0] = (B[0]) + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, but the immeidate parent is +// not a condition. +void testRegisterizerTwoConditionalLoops() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, cut in the middle. +void testRegisterizerTwoConditionalLoopsCut() { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr), + For::make(x, 0, 10, Store::make(a, {x}, 1, 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 
1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: for (int x +# CHECK: A[x] = 1; +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// references a Let var in a local scope which cannot be hoisted out of the +// loop. +void testRegisterizerLoopLetVar() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Let::make(y, 30), + Store::make(a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * int y = 30; + * A[y] = x + (A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// references a Let var in an outer scope that does not prevent hoisting the +// initializer. +void testRegisterizerLoopLetVarOuter() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = + Block::make({Let::make(y, 30), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))}); + + /* + * int y = 30; + * for (int x = 0; x < 10; x++) { + * A[y] = x + (A[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int y = 30; + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[y] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int y = 30; +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = x + A_1; +# CHECK: A[y] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Okay so the registerizer generally goes after index flattening, but just in +// case. Test multi index registerization. +void testRegisterizerMultiDim() { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 1, 2] = (A[0, 1, 2]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0, 1, 2] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0, 1, 2] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Wont registerize if only some dims match, but will still registerize distinct +// elements. 
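+// (Below, the store A[0, 2, 2] and the load A[0, 1, 4] each differ from
+// A[0, 1, 2] in a constant index, so each gets its own scalar while the
+// A[0, 1, 2] store is left untouched.)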
+void testRegisterizerMultiDimPartial() { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 2, 2] = (A[0, 1, 4]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[0, 1, 4]; + * int A_2 = A[0, 2, 2]; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * } + * A[0, 2, 2] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[0, 1, 4]; +# CHECK: int A_2 = A[0, 2, 2]; +# CHECK: for ( +# CHECK: A_2 = x + A_1; +# CHECK: A[0, 2, 2] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If they could overlap across all dimensions we cannot registerize. +void testRegisterizerMultiDimOverlap() { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 2]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; ASSERT_EQ(before.str(), after.str()); } +// But, if one dimension is known to be distinct they do not overlap. +void testRegisterizerMultiDimPartialOverlap() { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[y, 2, 4]; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = A_1 + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[y, 2, 4]; +# CHECK: for ( +# CHECK: A[0, x, 2] = A_1 + x; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with different input dimensionality. 
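+// (A[x] depends only on x and B[x, y] only on x and y, so those loads hoist to
+// the bodies of the x and y loops respectively, while C[x, y, z] uses all
+// three loop vars and stays in the buffer.)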
+void testRegisterizerMultiDim3DReduction1() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10, 10}, kInt); + BufHandle c("C", {10, 10, 10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x, y, z}, + Add::make( + Load::make(c, {x, y, z}, 1), + Mul::make( + Load::make(b, {x, y}, 1), Load::make(a, {x}, 1))), + 1)))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize the A and B access since they can be hoisted before + // hitting a dependent loop var. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[x, y]; + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[x, y]; +# CHECK: for (int z +# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with the same smaller dimensionality using different loop +// vars. +void testRegisterizerMultiDim3DReduction2() { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x}, + Add::make( + Load::make(c, {x}, 1), + Mul::make(Load::make(b, {y}, 1), Load::make(a, {x}, 1))), + 1)))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x] = (C[x]) + (B[y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize all accesses, the A and C access can be hoisted to the + // outer loop since they depend only on it's loop var while the B can only be + // raised to the loop of y. 
+ + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * int C_1 = C[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[y]; + * for (int z = 0; z < 10; z++) { + * C_1 = B_1 * A_1 + C_1; + * } + * } + * C[x] = C_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: int C_1 = C[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[y]; +# CHECK: for (int z +# CHECK: C_1 = B_1 * A_1 + C_1; +# CHECK: } +# CHECK: } +# CHECK: C[x] = C_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 4337c14fe3eb..c32183aaa042 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -9,296 +9,348 @@ namespace torch { namespace jit { -#define TH_FORALL_TENSOREXPR_TESTS(_) \ - _(ExprBasicValueTest) \ - _(ExprBasicValueTest02) \ - _(ExprLetTest01) \ - _(ExprLetStmtTest01) \ - _(ExprLetTest02) \ - _(ExprIntTest) \ - _(ExprFloatTest) \ - _(ExprByteTest) \ - _(ExprCharTest) \ - _(ExprShortTest) \ - _(ExprLongTest) \ - _(ExprHalfTest) \ - _(ExprDoubleTest) \ - _(ExprDisallowBoolArithmetic) \ - _(ExprVectorAdd01) \ - _(ExprCompareSelectEQ) \ - _(ExprCompareSelectDtypes) \ - _(ExprIntrinsicsDtypes) \ - _(ExprSubstitute01) \ - _(ExprMath01) \ - _(ExprUnaryMath01) \ - _(ExprBinaryMath01) \ - _(ExprDynamicShapeAdd) \ - _(ExprBitwiseOps) \ - _(IRPrinterBasicValueTest) \ - _(IRPrinterBasicValueTest02) \ - _(IRPrinterCastTest) \ - _(IRPrinterFunctionName) \ - _(ExprSimple01) \ - _(ExprLower01) \ - _(ExprSimple02) \ - _(ExprSliceHead) \ - _(ExprSliceHeadWhenFactorEqualsSize) \ - _(ExprSliceHeadWhenFactorLargerThanSize) \ - _(ExprSliceHeadWithLoopOptions) \ - _(ExprSliceHeadWithNonZeroStart) \ - _(ExprSliceTail) \ - _(ExprSliceTailWhenFactorEqualsSize) \ - _(ExprSliceTailWhenFactorLargerThanSize) \ - _(ExprSliceTailWithLoopOptions) \ - _(ExprSliceAndNormalize) \ - _(ExprSliceWithVariableDimension) \ - _(ExprSplitAndSlice) \ - _(ExprSplitWithTail) \ - _(ExprSplitWithTailNone) \ - _(ExprSplitWithMask01) \ - _(ExprSplitWithMaskRepeatedNoMask) \ - _(SplitWithTailWithLoopOptions) \ - _(SplitWithMaskWithLoopOptions) \ - _(ScheduleBroadcastAddBuffer) \ - _(ScheduleFunctionCall01) \ - _(ScheduleInlineSimple) \ - _(ScheduleInlineFunc01) \ - _(ScheduleInlineRandom) \ - _(ScheduleInlineRandomUnrelated) \ - _(ScheduleInlineRandomLowerDimensions) \ - _(ScheduleInlineIntrinsics) \ - _(ScheduleInlineRandWithIntrinsics) \ - _(ScheduleSplitAThenInline) \ - _(ScheduleSplitBThenInline) \ - _(ScheduleSplitTwiceThenInline) \ - _(ScheduleInlineThenSplit) \ - _(ScheduleSplitInlineThenSplit) \ - _(ScheduleSplitInlineSimplify) \ - _(ScheduleInlineThreeMixedOnce) \ - _(ScheduleInlineThreeMixedTwice) \ - _(ScheduleInlineThreeMixedInner) \ - _(ScheduleInlineThreeMixedSplit) \ - _(ScheduleFuserStyle) \ - _(ScheduleFuserThreeArg) \ - _(ScheduleDynamicShape2D) \ - _(ReduceSum1D) \ - _(ReduceSum2D) \ - _(ReduceSum3D) \ - _(ReduceSum10D) \ - _(ReduceProduct) \ - _(ReduceMax) \ - _(ReduceMinCustomInitializer) \ - _(ReduceAnyAll) \ - _(ReduceMatmul2D) \ - _(ReduceRfactorLike) \ - _(ReduceAsProducer) \ - _(ReduceAsConsumer) \ - _(SplitReduceAxis) \ - _(SplitNonReduceAxis) \ - _(ReorderedReductionInitializer) \ - _(ReduceRfactor) \ - _(Reduce3DRfactorInternal) \ - _(Reduce3DRfactorInner) \ - _(Reduce3DRfactorOuter) \ - 
_(Reduce3DRfactorWithOuter) \ - _(Reduce3DRfactorRepeated) \ - _(ReduceRfactorInsertionPoint) \ - _(Reduce3DRfactorInsertionPoint) \ - _(ReduceRepeatedInternalRfactor) \ - _(ReduceSplitTail) \ - _(ReduceSplitNoTail) \ - _(ReduceOverSplitTail) \ - _(ReduceSplitMask) \ - _(ReduceSplitNoMask) \ - _(ReduceOverSplitMask) \ - _(ReduceSplitRfactor) \ - _(ReduceOverSplitRfactor) \ - _(ReduceInlineReduction) \ - _(ReduceInlineConsumer) \ - _(ReduceInlineReducerInternal) \ - _(TypeTest01) \ - _(TypePropagation) \ - _(Cond01) \ - _(IfThenElse01) \ - _(IfThenElse02) \ - _(IfThenElse03) \ - _(ATen_cast_Float) \ - _(ATennegInt) \ - _(ATennegFloat) \ - _(ATenaddInt) \ - _(ATenaddFloat) \ - _(ATensubInt) \ - _(ATensubFloat) \ - _(ATenlerp) \ - _(ATenaddcmulInt) \ - _(ATenaddcmulFloat) \ - _(ATenmulInt) \ - _(ATenmulFloat) \ - _(ATendivInt) \ - _(ATendivFloat) \ - _(ATenmaxInt) \ - _(ATenmaxFloat) \ - _(ATenminInt) \ - _(ATenminFloat) \ - _(ATenreciprocal) \ - _(ATenreluInt) \ - _(ATenreluFloat) \ - _(ATenlogFloat) \ - _(ATenlog10Float) \ - _(ATenlog2Float) \ - _(ATenexpFloat) \ - _(ATenerfFloat) \ - _(ATencosFloat) \ - _(ATeneqInt) \ - _(ATengeInt) \ - _(ATengtInt) \ - _(ATenleInt) \ - _(ATenltInt) \ - _(ConstantFoldSimple) \ - _(ConstantFoldTwoLayer) \ - _(ConstantFoldShifts) \ - _(ConstantFoldBitwise) \ - _(ConstantFoldMultiOp) \ - _(ConstantFoldMinMax) \ - _(ConstantFoldIntrinsics) \ - _(ConstantFoldCastToBool) \ - _(ConstantFoldWithVar) \ - _(ConditionalSelectFoldSimple) \ - _(ConditionalSelectFoldTwoLayer) \ - _(ConditionalSelectFoldWithVar) \ - _(UnFoldableExpr) \ - _(HashSimple) \ - _(HashEquivalence) \ - _(HashEquivalenceRand) \ - _(HashEquivalenceAfterFolding) \ - _(HashDifferenceTypes) \ - _(HashLargeExpression) \ - _(HashForLoopOptions) \ - _(SimplifyAdd) \ - _(SimplifySub) \ - _(SimplifyMultiLayer) \ - _(SimplifyMultiTerm) \ - _(SimplifyCasts) \ - _(SimplifyEliminatesNoOps) \ - _(SimplifyMultiVar) \ - _(SimplifyEliminatesVar) \ - _(SimplifyAdds) \ - _(SimplifyMuls) \ - _(SimplifySubs) \ - _(SimplifyDiv) \ - _(SimplifyMultiOp) \ - _(SimplifyManyOps) \ - _(SimplifyFactorization) \ - _(SimplifyFactorizeUneven) \ - _(SimplifyDeeperTerms) \ - _(SimplifyDeeperDifference) \ - _(SimplifyFoldComplexDifference) \ - _(SimplifyIfComponents) \ - _(SimplifyOpaqueTerms) \ - _(SimplifySymbolicMinMax) \ - _(SimplifyNestedMax) \ - _(SimplifyNestedMin) \ - _(SimplifyWontReorderFloat) \ - _(SimplifyRoundModPattern) \ - _(SimplifyRoundModPatternFactorization) \ - _(SimplifyRoundModPatternMultivar) \ - _(SimplifyDivisionScalarFactorization) \ - _(SimplifyConstantBranches) \ - _(SimplifyConstantCond) \ - _(SimplifyEliminateEmptyCond) \ - _(SimplifyEliminateZeroLengthFor) \ - _(SimplifyOneLoopFor) \ - _(SimplifyForWontLoseLoopOptions) \ - _(SimplifyMultilevelFor) \ - _(SimplifyForCleansUp) \ - _(SimplifyEliminateEmptyFor) \ - _(SimplifyFlattenBlock) \ - _(SimplifyEliminateZeroLengthAlloc) \ - _(DontSimplifyRand) \ - _(SimplifyReorderForCond) \ - _(SimplifyFuseConditions) \ - _(SimplifySyncThreads) \ - _(SimplifyRampSubBroadcast) \ - _(SimplifyBroadcastTermExpander) \ - _(RegisterizerSimple) \ - _(RegisterizerLoop) \ - _(RegisterizerLoopFixedLoad) \ - _(RegisterizerMultiVar) \ - _(RegisterizerVariableLoad) \ - _(RegisterizerSymbolicIndices) \ - _(RegisterizerEarlyStop) \ - _(RegisterizerMultiLoop) \ - _(RegisterizerRepeated) \ - _(RegisterizerNoLoads) \ - _(RegisterizerNoRepeatedStores) \ - _(RegisterizerMultiVarOverlap) \ - _(RegisterizerAllocs) \ - _(RegisterizerNoInitializer) \ - _(RegisterizerLoadThenStore) \ - 
_(RegisterizerParallelized) \ - _(RegisterizerConditions) \ - _(StmtClone) \ - _(BoundsInference_1) \ - _(BoundsInference_2) \ - _(BoundsInference_3) \ - _(BoundsInference_4) \ - _(BoundsInference_5) \ - _(BoundsInference_6) \ - _(BoundsInferenceNonOverlapping) \ - _(BoundsInferenceAdjacent) \ - _(MergeInferredBounds) \ - _(MergeInferredLoadStoreDiff) \ - _(MergeInferred2DBounds) \ - _(MergeAdjacentBounds) \ - _(MergeSymbolicBounds) \ - _(MergeSymbolicAdjacent) \ - _(LoopNestComputeAt_1) \ - _(LoopNestComputeAt_2) \ - _(LoopNestComputeAt_3) \ - _(LoopNestComputeAt_4) \ - _(LoopNestReorderAxis1) \ - _(LoopNestReorderPartialAxes) \ - _(LoopNestReorderInternalAxis) \ - _(LoopNestReorderEnclosingAxis) \ - _(LoopNestReorderSameAxis) \ - _(LoopNestReorderExtraStatements) \ - _(LoopNestReorderLongStringOfPreOrphans) \ - _(LoopNestReorderLongStringOfPostOrphans) \ - _(LoopNestReorderLongStringFull) \ - _(LoopNestReorderInternalLoopNest) \ - _(OuterLoopVectorization) \ - _(Unroll) \ - _(UnrollOuter) \ - _(UnrollInner) \ - _(UnrollMultipleStatements) \ - _(UnrollEmpty) \ - _(NoUnroll) \ - _(UnrollWithLet) \ - _(NormalizeStartPositive) \ - _(NormalizeStartNegative) \ - _(NormalizeStartZero) \ - _(NormalizeStartVariable) \ - _(NormalizeOnNestedOuterLoop) \ - _(NormalizeOnNestedInnerLoop) \ - _(NormalizeAndSplitWithTail) \ - _(DetectInlineRankMismatch) \ - _(Kernel_1) \ - _(Kernel_2) \ - _(Kernel_3) \ - _(Kernel_4) \ - _(KernelSumAllAxes) \ - _(KernelSumOneAxis) \ - _(KernelSumMultipleAxes) \ - _(FuserPass_1) \ - _(FuserPass_2) \ - _(FuserPass_3) \ - _(FuserPass_0DimInput) \ - _(FuserPass_UnfusibleDevice) \ - _(FuserPass_UnknownShapes) \ - _(FuserPass_UnknownShapesIgnored) \ - _(FuserPass_Multidevice) \ - _(FuserPass_MergeGroups) \ - _(TrainBasic) \ +#define TH_FORALL_TENSOREXPR_TESTS(_) \ + _(ExprBasicValueTest) \ + _(ExprBasicValueTest02) \ + _(ExprLetTest01) \ + _(ExprLetStmtTest01) \ + _(ExprLetTest02) \ + _(ExprIntTest) \ + _(ExprFloatTest) \ + _(ExprByteTest) \ + _(ExprCharTest) \ + _(ExprShortTest) \ + _(ExprLongTest) \ + _(ExprHalfTest) \ + _(ExprDoubleTest) \ + _(ExprDisallowBoolArithmetic) \ + _(ExprVectorAdd01) \ + _(ExprCompareSelectEQ) \ + _(ExprCompareSelectDtypes) \ + _(ExprIntrinsicsDtypes) \ + _(ExprSubstitute01) \ + _(ExprMath01) \ + _(ExprUnaryMath01) \ + _(ExprBinaryMath01) \ + _(ExprDynamicShapeAdd) \ + _(ExprBitwiseOps) \ + _(IRPrinterBasicValueTest) \ + _(IRPrinterBasicValueTest02) \ + _(IRPrinterCastTest) \ + _(IRPrinterFunctionName) \ + _(ExprSimple01) \ + _(ExprLower01) \ + _(ExprSimple02) \ + _(ExprSliceHead) \ + _(ExprSliceHeadWhenFactorEqualsSize) \ + _(ExprSliceHeadWhenFactorLargerThanSize) \ + _(ExprSliceHeadWithLoopOptions) \ + _(ExprSliceHeadWithNonZeroStart) \ + _(ExprSliceTail) \ + _(ExprSliceTailWhenFactorEqualsSize) \ + _(ExprSliceTailWhenFactorLargerThanSize) \ + _(ExprSliceTailWithLoopOptions) \ + _(ExprSliceAndNormalize) \ + _(ExprSliceWithVariableDimension) \ + _(ExprSplitAndSlice) \ + _(ExprSplitWithTail) \ + _(ExprSplitWithTailNone) \ + _(ExprSplitWithMask01) \ + _(ExprSplitWithMaskRepeatedNoMask) \ + _(SplitWithTailWithLoopOptions) \ + _(SplitWithMaskWithLoopOptions) \ + _(ScheduleBroadcastAddBuffer) \ + _(ScheduleFunctionCall01) \ + _(ScheduleInlineSimple) \ + _(ScheduleInlineFunc01) \ + _(ScheduleInlineRandom) \ + _(ScheduleInlineRandomUnrelated) \ + _(ScheduleInlineRandomLowerDimensions) \ + _(ScheduleInlineIntrinsics) \ + _(ScheduleInlineRandWithIntrinsics) \ + _(ScheduleSplitAThenInline) \ + _(ScheduleSplitBThenInline) \ + _(ScheduleSplitTwiceThenInline) 
\ + _(ScheduleInlineThenSplit) \ + _(ScheduleSplitInlineThenSplit) \ + _(ScheduleSplitInlineSimplify) \ + _(ScheduleInlineThreeMixedOnce) \ + _(ScheduleInlineThreeMixedTwice) \ + _(ScheduleInlineThreeMixedInner) \ + _(ScheduleInlineThreeMixedSplit) \ + _(ScheduleFuserStyle) \ + _(ScheduleFuserThreeArg) \ + _(ScheduleDynamicShape2D) \ + _(ReduceSum1D) \ + _(ReduceSum2D) \ + _(ReduceSum3D) \ + _(ReduceSum10D) \ + _(ReduceProduct) \ + _(ReduceMax) \ + _(ReduceMinCustomInitializer) \ + _(ReduceAnyAll) \ + _(ReduceMatmul2D) \ + _(ReduceRfactorLike) \ + _(ReduceAsProducer) \ + _(ReduceAsConsumer) \ + _(SplitReduceAxis) \ + _(SplitNonReduceAxis) \ + _(ReorderedReductionInitializer) \ + _(ReduceRfactor) \ + _(Reduce3DRfactorInternal) \ + _(Reduce3DRfactorInner) \ + _(Reduce3DRfactorOuter) \ + _(Reduce3DRfactorWithOuter) \ + _(Reduce3DRfactorRepeated) \ + _(ReduceRfactorInsertionPoint) \ + _(Reduce3DRfactorInsertionPoint) \ + _(ReduceRepeatedInternalRfactor) \ + _(ReduceSplitTail) \ + _(ReduceSplitNoTail) \ + _(ReduceOverSplitTail) \ + _(ReduceSplitMask) \ + _(ReduceSplitNoMask) \ + _(ReduceOverSplitMask) \ + _(ReduceSplitRfactor) \ + _(ReduceOverSplitRfactor) \ + _(ReduceInlineReduction) \ + _(ReduceInlineConsumer) \ + _(ReduceInlineReducerInternal) \ + _(TypeTest01) \ + _(TypePropagation) \ + _(Cond01) \ + _(IfThenElse01) \ + _(IfThenElse02) \ + _(IfThenElse03) \ + _(ATen_cast_Float) \ + _(ATennegInt) \ + _(ATennegFloat) \ + _(ATenaddInt) \ + _(ATenaddFloat) \ + _(ATensubInt) \ + _(ATensubFloat) \ + _(ATenlerp) \ + _(ATenaddcmulInt) \ + _(ATenaddcmulFloat) \ + _(ATenmulInt) \ + _(ATenmulFloat) \ + _(ATendivInt) \ + _(ATendivFloat) \ + _(ATenmaxInt) \ + _(ATenmaxFloat) \ + _(ATenminInt) \ + _(ATenminFloat) \ + _(ATenreciprocal) \ + _(ATenreluInt) \ + _(ATenreluFloat) \ + _(ATenlogFloat) \ + _(ATenlog10Float) \ + _(ATenlog2Float) \ + _(ATenexpFloat) \ + _(ATenerfFloat) \ + _(ATencosFloat) \ + _(ATeneqInt) \ + _(ATengeInt) \ + _(ATengtInt) \ + _(ATenleInt) \ + _(ATenltInt) \ + _(ConstantFoldSimple) \ + _(ConstantFoldTwoLayer) \ + _(ConstantFoldShifts) \ + _(ConstantFoldBitwise) \ + _(ConstantFoldMultiOp) \ + _(ConstantFoldMinMax) \ + _(ConstantFoldIntrinsics) \ + _(ConstantFoldCastToBool) \ + _(ConstantFoldWithVar) \ + _(ConditionalSelectFoldSimple) \ + _(ConditionalSelectFoldTwoLayer) \ + _(ConditionalSelectFoldWithVar) \ + _(UnFoldableExpr) \ + _(HashSimple) \ + _(HashEquivalence) \ + _(HashEquivalenceRand) \ + _(HashEquivalenceAfterFolding) \ + _(HashDifferenceTypes) \ + _(HashLargeExpression) \ + _(HashForLoopOptions) \ + _(SimplifyAdd) \ + _(SimplifySub) \ + _(SimplifyMultiLayer) \ + _(SimplifyMultiTerm) \ + _(SimplifyCasts) \ + _(SimplifyEliminatesNoOps) \ + _(SimplifyMultiVar) \ + _(SimplifyEliminatesVar) \ + _(SimplifyAdds) \ + _(SimplifyMuls) \ + _(SimplifySubs) \ + _(SimplifyDiv) \ + _(SimplifyMultiOp) \ + _(SimplifyManyOps) \ + _(SimplifyFactorization) \ + _(SimplifyFactorizeUneven) \ + _(SimplifyDeeperTerms) \ + _(SimplifyDeeperDifference) \ + _(SimplifyFoldComplexDifference) \ + _(SimplifyIfComponents) \ + _(SimplifyOpaqueTerms) \ + _(SimplifySymbolicMinMax) \ + _(SimplifyNestedMax) \ + _(SimplifyNestedMin) \ + _(SimplifyWontReorderFloat) \ + _(SimplifyRoundModPattern) \ + _(SimplifyRoundModPatternFactorization) \ + _(SimplifyRoundModPatternMultivar) \ + _(SimplifyDivisionScalarFactorization) \ + _(SimplifyConstantBranches) \ + _(SimplifyConstantCond) \ + _(SimplifyEliminateEmptyCond) \ + _(SimplifyEliminateZeroLengthFor) \ + _(SimplifyOneLoopFor) \ + _(SimplifyForWontLoseLoopOptions) \ 
+ _(SimplifyMultilevelFor) \ + _(SimplifyForCleansUp) \ + _(SimplifyEliminateEmptyFor) \ + _(SimplifyFlattenBlock) \ + _(SimplifyEliminateZeroLengthAlloc) \ + _(DontSimplifyRand) \ + _(SimplifyReorderForCond) \ + _(SimplifyFuseConditions) \ + _(SimplifySyncThreads) \ + _(SimplifyRampSubBroadcast) \ + _(SimplifyBroadcastTermExpander) \ + _(RegisterizerSimple) \ + _(RegisterizerLoop) \ + _(RegisterizerLoopFixedLoad) \ + _(RegisterizerLoopInternal) \ + _(RegisterizerLoopInternalLoadOverlap) \ + _(RegisterizerLoopInternalRepeated) \ + _(RegisterizerLoopInternalRepeatedOverlapLoopVar) \ + _(RegisterizerLoopInternalRepeatedOverlapOther) \ + _(RegisterizerMultiVar) \ + _(RegisterizerVariableLoad) \ + _(RegisterizerSymbolicIndices) \ + _(RegisterizerMultiLoop) \ + _(RegisterizerRepeated) \ + _(RegisterizerNoLoads) \ + _(RegisterizerNoRepeatedStores) \ + _(RegisterizerMultiVarOverlap) \ + _(RegisterizerAllocs) \ + _(RegisterizerNoInitializer) \ + _(RegisterizerNoInitializerLoopVar) \ + _(RegisterizerLoadThenStore) \ + _(RegisterizerParallelized) \ + _(RegisterizerConditionAfter) \ + _(RegisterizerConditionBefore) \ + _(RegisterizerConditionInside) \ + _(RegisterizerConditionInsideOverlap1) \ + _(RegisterizerConditionInsideOverlap2) \ + _(RegisterizerConditionHidden) \ + _(RegisterizerConditionUnhidden) \ + _(RegisterizerCondCondition) \ + _(RegisterizerCondConditionUnhidden) \ + _(RegisterizerIfThenElseHidden) \ + _(RegisterizerIfThenElseUnhidden) \ + _(RegisterizerIfThenElseNested) \ + _(RegisterizerIfThenElseInternal) \ + _(RegisterizerIfThenElseCondition) \ + _(RegisterizerIfThenElseConditionUnhidden) \ + _(RegisterizerConditionBranchOnly) \ + _(RegisterizerCondIfThenElse) \ + _(RegisterizerIfThenElseLoop) \ + _(RegisterizerIfThenElseLoopCut) \ + _(RegisterizerPartialAfter) \ + _(RegisterizerPartialBefore) \ + _(RegisterizerPartialInside) \ + _(RegisterizerPartialCondition) \ + _(RegisterizerPartialConditionInternalCut) \ + _(RegisterizerPartialConditionInternalStart) \ + _(RegisterizerPartialOverlapsTwo) \ + _(RegisterizerNestedBlocks) \ + _(RegisterizerNestedConditions) \ + _(RegisterizerNestedConditionsUnhidden) \ + _(RegisterizerNestedConditionsHiddenFirst) \ + _(RegisterizerNestedConditionsHiddenSecond) \ + _(RegisterizerNestedConditionsCut) \ + _(RegisterizerNestedConditionLoopHidden) \ + _(RegisterizerNestedConditionThreeDeep) \ + _(RegisterizerNestedLoopSimple) \ + _(RegisterizerHiddenAccessYes) \ + _(RegisterizerHiddenAccessNo) \ + _(RegisterizerHiddenAccessMultiLoop) \ + _(RegisterizerTwoConditionalLoops) \ + _(RegisterizerTwoConditionalLoopsCut) \ + _(RegisterizerLoopLetVar) \ + _(RegisterizerLoopLetVarOuter) \ + _(RegisterizerMultiDim) \ + _(RegisterizerMultiDimPartial) \ + _(RegisterizerMultiDimOverlap) \ + _(RegisterizerMultiDimPartialOverlap) \ + _(RegisterizerMultiDim3DReduction1) \ + _(RegisterizerMultiDim3DReduction2) \ + _(StmtClone) \ + _(BoundsInference_1) \ + _(BoundsInference_2) \ + _(BoundsInference_3) \ + _(BoundsInference_4) \ + _(BoundsInference_5) \ + _(BoundsInference_6) \ + _(BoundsInferenceNonOverlapping) \ + _(BoundsInferenceAdjacent) \ + _(MergeInferredBounds) \ + _(MergeInferredLoadStoreDiff) \ + _(MergeInferred2DBounds) \ + _(MergeAdjacentBounds) \ + _(MergeSymbolicBounds) \ + _(MergeSymbolicAdjacent) \ + _(LoopNestComputeAt_1) \ + _(LoopNestComputeAt_2) \ + _(LoopNestComputeAt_3) \ + _(LoopNestComputeAt_4) \ + _(LoopNestReorderAxis1) \ + _(LoopNestReorderPartialAxes) \ + _(LoopNestReorderInternalAxis) \ + _(LoopNestReorderEnclosingAxis) \ + 
_(LoopNestReorderSameAxis) \ + _(LoopNestReorderExtraStatements) \ + _(LoopNestReorderLongStringOfPreOrphans) \ + _(LoopNestReorderLongStringOfPostOrphans) \ + _(LoopNestReorderLongStringFull) \ + _(LoopNestReorderInternalLoopNest) \ + _(OuterLoopVectorization) \ + _(Unroll) \ + _(UnrollOuter) \ + _(UnrollInner) \ + _(UnrollMultipleStatements) \ + _(UnrollEmpty) \ + _(NoUnroll) \ + _(UnrollWithLet) \ + _(NormalizeStartPositive) \ + _(NormalizeStartNegative) \ + _(NormalizeStartZero) \ + _(NormalizeStartVariable) \ + _(NormalizeOnNestedOuterLoop) \ + _(NormalizeOnNestedInnerLoop) \ + _(NormalizeAndSplitWithTail) \ + _(DetectInlineRankMismatch) \ + _(Kernel_1) \ + _(Kernel_2) \ + _(Kernel_3) \ + _(Kernel_4) \ + _(KernelSumAllAxes) \ + _(KernelSumOneAxis) \ + _(KernelSumMultipleAxes) \ + _(FuserPass_1) \ + _(FuserPass_2) \ + _(FuserPass_3) \ + _(FuserPass_0DimInput) \ + _(FuserPass_UnfusibleDevice) \ + _(FuserPass_UnknownShapes) \ + _(FuserPass_UnknownShapesIgnored) \ + _(FuserPass_Multidevice) \ + _(FuserPass_MergeGroups) \ + _(TrainBasic) \ _(Conv2D) #define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \ diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 37c856a2e618..1d67d5359a1d 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -1909,6 +1909,7 @@ Block* TermExpander::fuseConditions(Block* v) { stmts.push_back(s); continue; } + // Fuse the two Conds by appending the bodies of the second Cond to the // first. Block* true_block = new Block({}); @@ -1939,11 +1940,13 @@ Block* TermExpander::fuseConditions(Block* v) { false_block = nullptr; } - prev_cond = prev_cond->cloneWithNewBodies(true_block, false_block); + Stmt* new_cond = prev_cond->cloneWithNewBodies(true_block, false_block) + ->accept_mutator(this); + prev_cond = dynamic_cast(new_cond); // erase, which shortens the list. stmts.pop_back(); - stmts.push_back(prev_cond); + stmts.push_back(new_cond); did_anything = true; } diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 7181e9ec134a..6e4383b09196 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -3,6 +3,181 @@ namespace torch { namespace jit { namespace tensorexpr { +namespace registerizer { + +// AccessInfo + +void AccessInfo::addStore( + const Store* store, + const std::shared_ptr& scope) { + block_ = + block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); + + // If there is already a usage and it's this store, that means the same + // access is present in the RHS. + firstUsageOverlapped_ |= first_usage_ == store; + first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store; + last_usage_ = store; + + store_cost_ = IRSimplifier::simplify(new Add(store_cost_, new IntImm(1))); + stores_.push_back(store); + + conditionId_ = scope->conditionId(); + hiddenAccess_.reset(); +} + +void AccessInfo::addLoad( + const Load* load, + const std::shared_ptr& scope, + const Stmt* usage) { + block_ = + block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); + first_usage_ = first_usage_ ? 
block_->getEnclosedRoot(first_usage_) : usage; + last_usage_ = usage; + + load_cost_ = IRSimplifier::simplify(new Add(load_cost_, new IntImm(1))); + loads_.push_back(load); + + conditionId_ = scope->conditionId(); + hiddenAccess_.reset(); +} + +void AccessInfo::merge(const std::shared_ptr& other) { + TORCH_INTERNAL_ASSERT(hash_ == other->hash()); + TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + + last_usage_ = other->last_usage(); + for (auto* s : other->stores()) { + stores_.push_back(s); + } + for (auto* l : other->loads()) { + loads_.push_back(l); + } + + store_cost_ = + IRSimplifier::simplify(new Add(store_cost_, other->store_cost())); + load_cost_ = IRSimplifier::simplify(new Add(load_cost_, other->load_cost())); + + block_ = Block::getSharedParent(block_, other->block()); + // update first and last usage to be in the parent Block. + first_usage_ = block_->getEnclosedRoot(first_usage_); + last_usage_ = block_->getEnclosedRoot(last_usage_); + hiddenAccess_.reset(); +} + +bool AccessInfo::overlaps(const std::shared_ptr& other) { + // All accesses to a buf must have the same dimensionality. + TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + + const auto& other_indices = other->indices(); + + // They don't overlap if there is a guaranteed difference in any + // dimension. + bool overlap = true; + for (size_t i = 0; i < indices_.size(); ++i) { + const Expr* diff = new Sub(indices_[i], other_indices[i]); + diff = IRSimplifier::simplify(diff); + + if (diff->isConstant() && !immediateEquals(diff, 0)) { + overlap = false; + break; + } + } + + return overlap; +} + +bool AccessInfo::dependsOnVar(const Var* v) { + VarFinder vf; + for (auto* i : indices_) { + i->accept(&vf); + } + + return vf.vars().count(v); +} + +std::shared_ptr AccessInfo::cloneWithHiddenInfo( + const std::shared_ptr& orig) { + std::shared_ptr newInfo = std::make_shared( + orig->hash(), orig->buf(), orig->indices(), orig->accessOrder()); + + newInfo->block_ = orig->block_; + newInfo->first_usage_ = orig->first_usage_; + newInfo->last_usage_ = orig->last_usage_; + newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_; + newInfo->store_cost_ = orig->store_cost_; + newInfo->load_cost_ = orig->load_cost_; + for (auto* s : orig->stores_) { + newInfo->stores_.push_back(s); + } + for (auto* s : orig->loads_) { + newInfo->loads_.push_back(s); + } + + newInfo->conditionId_ = orig->conditionId_; + newInfo->hiddenAccess_ = orig; + return newInfo; +} + +void AccessInfo::print() const { + std::cout << "Access: " << *buf_ << "{"; + for (auto* i : indices_) { + std::cout << *i << " "; + } + std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -"; + std::cout << " loads: " << loads_.size() << " (" << *load_cost_ << ")"; + if (conditionId_) { + std::cout << " cond: " << conditionId_; + } + + std::cout << "\n"; +} + +// Scope + +void Scope::closeAccess(const std::shared_ptr& info) { + closedAccesses_.push_back(info); +} + +AccessHashMap& Scope::getAccessMapByBuf(const Buf* b) { + auto it = openAccesses_.find(b); + if (it == openAccesses_.end()) { + // create and return + return openAccesses_[b]; + } + + return it->second; +} + +void Scope::filterClosed() { + closedAccesses_.erase( + std::remove_if( + closedAccesses_.begin(), + closedAccesses_.end(), + [](auto info) { + return info->store_cost()->isConstant() && + immediateAs(info->store_cost()) <= 1 && + info->load_cost()->isConstant() && + immediateAs(info->load_cost()) <= 1; + }), + closedAccesses_.end()); +} + +// 
RegisterizerAnalysis + +void RegisterizerAnalysis::closeAccessIntoScope( + const std::shared_ptr& info, + const std::shared_ptr& scope) { + if (exprConditionals_.count(info->conditionId()) != 0) { + return; + } + + if (info->hiddenAccess()) { + closeAccessIntoScope(info->hiddenAccess(), scope); + return; + } + scope->closeAccess(info); +} void RegisterizerAnalysis::visit(const For* v) { if (v->loop_options().is_gpu_block_index() || @@ -11,28 +186,196 @@ void RegisterizerAnalysis::visit(const For* v) { "Registerization must occur after parallelism flattening"); } - const Expr* old_loopCost = loopCost_; - loopCost_ = IRSimplifier::simplify( - new Mul(loopCost_, new Sub(v->stop(), v->start()))); + auto parent = currentScope_; + currentScope_ = std::make_shared(v->body(), parent); + + currentScope_->addLocalVar(v->var()); + stmtStack_.push_front(v); v->body()->accept(this); stmtStack_.pop_front(); - loopCost_ = old_loopCost; + const Expr* loopExtent = + IRSimplifier::simplify(new Sub(v->stop(), v->start())); + + // now we need to see which accesses we can hoist out of the for loop, their + // costs should be multiplied by the loop extent. + for (auto& pair : currentScope_->openAccesses()) { + const Buf* buf = pair.first; + if (pair.second.empty()) { + continue; + } + + auto& childAccesses = pair.second; + + for (auto it = childAccesses.begin(); it != childAccesses.end();) { + std::shared_ptr& candidate = it->second; + + // If the access is open, but conditional, then we have a problem. It's + // possible that an access at a higher scope could "unhide" the + // conditional access, in which case we need to hoist. If there is no + // access to this element at a higher scope then we cannot safely hoist. + // We cannot know at this level whether that will or wont occur. + // + // The solution we take here is to split the space-time continuum, and + // keep both versions of the access handy. If the hoisted access is not + // used above, we'll fall back to using the hidden, conditional + // AccessInfo - if it is, we'll delete the copy. + if (candidate->conditionId() != 0) { + candidate = AccessInfo::cloneWithHiddenInfo(candidate); + } + + bool closed = false; + // If this access depends on a locally scoped variable, it cannot be + // hosted out of the loop. + for (auto* v : currentScope_->localVars()) { + if (candidate->dependsOnVar(v)) { + closeAccessIntoScope(candidate, currentScope_); + closed = true; + break; + } + } + if (closed) { + it = childAccesses.erase(it); + continue; + } + + // hoist! + // By hoisting we pull the reads and writes out of the loop, and so the + // benefit of registerizing this access is multiplied by the loop extent. + candidate->setEnclosingBlock(parent->block()); + candidate->hoistCosts(loopExtent); + + // in the parent block, this loop Stmt is the insertion point for the + // initializer and finalizer. + candidate->setUsageMarks(v, v); + + ++it; + } + } + + // If an access is closed within a loop then it cannot be merged into an + // existing open access, but will still close that existing access. This is + // somewhat different from the regular merge so we need to handle closed + // accesses first. + mergeHiddenScope(true); + + // having hoisted, now we can merge normally. + mergeCurrentScopeIntoParent(); }; +void RegisterizerAnalysis::visit(const Cond* v) { + const Expr* condition = v->condition(); + Block* true_stmt = v->true_stmt(); + Block* false_stmt = v->false_stmt(); + + stmtStack_.push_front(v); + + // condition is in the enclosing scope. 
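+  // Accesses made while visiting the condition are therefore recorded against
+  // currentScope_, not against the branch scopes created below.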
+ condition->accept(this); + + auto prev_scope = currentScope_; + auto true_scope = + std::make_shared(true_stmt, prev_scope, ++conditionId_); + auto false_scope = + std::make_shared(false_stmt, prev_scope, ++conditionId_); + + if (true_stmt) { + currentScope_ = true_scope; + true_stmt->accept(this); + mergeHiddenScope(true); + mergeCurrentScopeIntoParent(); + } + if (false_stmt) { + currentScope_ = false_scope; + false_stmt->accept(this); + mergeHiddenScope(true); + mergeCurrentScopeIntoParent(); + } + + // TODO: even though both scopes are conditional, we can merge accesses if + // they totally overlap in both branches, since we can guarantee one + // definition will be hit. We might need a 3-way merge? Not as simple as + // merging the true and false scopes together first. + + stmtStack_.pop_front(); +} + +// IfThenElses are just like Conds except they are not Stmts, which means no +// registerization can occur internally. However, the first reference to an +// access can occur within one if its visible outside the condition. +void RegisterizerAnalysis::visit(const IfThenElse* v) { + const Expr* condition = v->condition(); + const Expr* true_value = v->true_value(); + const Expr* false_value = v->false_value(); + + // condition is in enclosing scope. + condition->accept(this); + + auto prev_scope = currentScope_; + auto true_scope = + std::make_shared(prev_scope->block(), prev_scope, ++conditionId_); + auto false_scope = + std::make_shared(prev_scope->block(), prev_scope, ++conditionId_); + + // We store IfThenElse scopes in a global map, which we use to prevent closing + // any access that would require inserting statements in the values, which + // cannot enclose Stmts. + exprConditionals_.insert(true_scope->conditionId()); + exprConditionals_.insert(false_scope->conditionId()); + + if (true_value) { + currentScope_ = true_scope; + true_value->accept(this); + mergeHiddenScope(false); + mergeCurrentScopeIntoParent(); + } + + if (false_value) { + currentScope_ = false_scope; + false_value->accept(this); + mergeHiddenScope(false); + mergeCurrentScopeIntoParent(); + } +} + +void RegisterizerAnalysis::visit(const Let* v) { + currentScope_->addLocalVar(v->var()); + + stmtStack_.push_front(v); + v->value()->accept(this); + stmtStack_.pop_front(); +} + void RegisterizerAnalysis::visit(const Block* v) { - const Block* last = enclosingBlock_; - enclosingBlock_ = v; + auto prev_scope = currentScope_; + if (currentScope_->block() != v) { + currentScope_ = std::make_shared(v, prev_scope); + } + stmtStack_.push_front(v); - costByBlock_[v] = loopCost_; - IRVisitor::visit(v); + + for (auto* s : *v) { + s->accept(this); + if (currentScope_->block() != v) { + // merge the inner block's accesses into this Block's accesses. + mergeCurrentScopeIntoParent(); + } + } + stmtStack_.pop_front(); - enclosingBlock_ = last; + + if (prev_scope->block() == nullptr) { + // close any open candidates. + for (auto& p1 : currentScope_->openAccesses()) { + for (auto& p2 : p1.second) { + closeAccessIntoScope(p2.second, currentScope_); + } + } + } } void RegisterizerAnalysis::visit(const Store* v) { - // path into value first. 
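+  // Visit the stored value first (with this Store on the stmt stack) so any
+  // Loads on the RHS are recorded before the Store itself.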
stmtStack_.push_front(v); v->value()->accept(this); stmtStack_.pop_front(); @@ -42,26 +385,49 @@ void RegisterizerAnalysis::visit(const Store* v) { return; } + // hash the Store: SimplifierHashType accessHash = hasher_.hash(v->buf()); for (auto* i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } accessHash = hasher_.hash_combine(accessHash, v->mask()); - std::shared_ptr info; - auto candidateIt = candidates_.find(accessHash); - if (candidateIt != candidates_.end()) { - info = candidateIt->second; - } else { - info = std::make_shared(v->buf(), v->indices()); - candidates_[accessHash] = info; - encounterOrder_.push_back(info); + auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf()); + auto candidateIt = bufAccesses.find(accessHash); + + // If an identical access already exists, add this Store to it. + if (candidateIt != bufAccesses.end()) { + candidateIt->second->addStore(v, currentScope_); + return; } - if (nested_conditions_ > 0) { - info->invalid = true; + // Otherwise make a new AccessInfo and add this store. + auto info = std::make_shared( + accessHash, v->buf(), v->indices(), accessOrder_++); + info->addStore(v, currentScope_); + + // This new access may overlap an existing open access, in which case we need + // to close the older of the two. + bool alreadyOverlapped = false; + for (auto it = bufAccesses.begin(); it != bufAccesses.end();) { + auto other = it->second; + if (info->overlaps(other)) { + if (other->last_usage() == v) { + // we are already overlapped by an access in the RHS. + alreadyOverlapped = true; + } + closeAccessIntoScope(other, currentScope_); + it = bufAccesses.erase(it); + } else { + ++it; + } + } + + if (alreadyOverlapped) { + closeAccessIntoScope(info, currentScope_); + } else { + bufAccesses.emplace(accessHash, info); } - info->addStore(v, enclosingBlock_, loopCost_); } void RegisterizerAnalysis::visit(const Load* v) { @@ -69,281 +435,360 @@ void RegisterizerAnalysis::visit(const Load* v) { // already a scalar. return; } - + // hash the Load: SimplifierHashType accessHash = hasher_.hash(v->buf()); for (auto* i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } accessHash = hasher_.hash_combine(accessHash, v->mask()); - std::shared_ptr info; - auto candidateIt = candidates_.find(accessHash); - if (candidateIt != candidates_.end()) { - info = candidateIt->second; - } else { - info = std::make_shared(v->buf(), v->indices()); - candidates_[accessHash] = info; - encounterOrder_.push_back(info); + auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf()); + auto candidateIt = bufAccesses.find(accessHash); + if (candidateIt != bufAccesses.end()) { + // found the right access, can just insert. + candidateIt->second->addLoad(v, currentScope_, stmtStack_.front()); + return; } - if (nested_conditions_ > 0) { - info->invalid = true; + std::shared_ptr info = std::make_shared( + accessHash, v->buf(), v->indices(), accessOrder_++); + info->addLoad(v, currentScope_, stmtStack_.front()); + + bool alreadyOverlapped = false; + // This new access may overlap an existing open access, in which case we need + // to finalize the older of the two. + for (auto it = bufAccesses.begin(); it != bufAccesses.end();) { + auto other = it->second; + if (info->overlaps(other)) { + if (info->last_usage() == other->last_usage()) { + // if these two accesses are from the same Stmt, they already overlap + // each other. 
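+        // In that case the new access cannot stay open; it is closed right
+        // after this loop.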
+ alreadyOverlapped = true; + } + closeAccessIntoScope(other, currentScope_); + it = bufAccesses.erase(it); + } else { + ++it; + } } - info->addLoad(v, enclosingBlock_, loopCost_, stmtStack_.front()); + if (alreadyOverlapped) { + closeAccessIntoScope(info, currentScope_); + } else { + bufAccesses.emplace(accessHash, info); + } } -void RegisterizerAnalysis::visit(const IfThenElse* v) { - v->condition()->accept(this); - nested_conditions_++; - v->true_value()->accept(this); - v->false_value()->accept(this); - nested_conditions_--; -} +// Loop and Conditional scopes are different in that it may or may not be +// possible to hoist the intializer of a scalar variable outside the block +// depending on if we can tell that the Buffer access is valid outside. This is +// tricky because the access that demonstrates this may be later in the tree and +// we haven't encountered it yet. +// The allowClosed flag indicates whether we want to keep the closed accesses +// (For and Cond), or not (IfThenElse). +void RegisterizerAnalysis::mergeHiddenScope(bool allowClosed) { + // The rule is that if any access is closed within the conditional block, any + // accesses which overlap it must also be closed - since their initializer + // cannot be hoisted out of the block. + std::list> newClosed; + for (auto& info : currentScope_->closedAccesses()) { + auto& candidates = currentScope_->getAccessMapByBuf(info->buf()); + for (auto it = candidates.begin(); it != candidates.end();) { + std::shared_ptr candidate = it->second; + + if (info->hash() == candidate->hash() || info->overlaps(candidate)) { + newClosed.push_back(candidate); + it = candidates.erase(it); + } else { + ++it; + } + } + } -void RegisterizerAnalysis::visit(const Cond* v) { - const Expr* condition = v->condition(); - Stmt* true_stmt = v->true_stmt(); - Stmt* false_stmt = v->false_stmt(); - condition->accept(this); + if (allowClosed) { + for (auto& info : newClosed) { + closeAccessIntoScope(info, currentScope_); + } + } else { + currentScope_->closedAccesses().clear(); + } +} - stmtStack_.push_front(v); - nested_conditions_++; +// Merge currentScope_ into it's parent, and make parent the new currentScope_. +void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { + auto parent = currentScope_->parent(); - if (true_stmt) { - true_stmt->accept(this); - } - if (false_stmt) { - false_stmt->accept(this); - } + // copy across current closed accceses, merging / closing as necessary + for (auto& candidate : currentScope_->closedAccesses()) { + auto& parentAccesses = parent->getAccessMapByBuf(candidate->buf()); - nested_conditions_--; - stmtStack_.pop_front(); -} + auto parentIt = parentAccesses.find(candidate->hash()); + if (parentIt != parentAccesses.end()) { + std::shared_ptr pCandidate = parentIt->second; -std::vector> RegisterizerAnalysis::getCandidates() { - std::vector> ret; - - // Group accesses by the base buffer they refer to, so it's easier to - // determine which accesses may overlap. - std::unordered_map>> - access_by_buf; - for (const auto& pair : candidates_) { - std::shared_ptr info = pair.second; - - // We can "hoist" an access up the syntax tree if it's indices do not - // depend on any loop vars. - VarFinder vf; - for (auto* i : info->indices) { - i->accept(&vf); - } + // if the access is closed inside a condition, it can only be merged if + // the parent is in the same condition. + if (candidate->conditionId() && + pCandidate->conditionId() != candidate->conditionId()) { + // the parent's access must be closed. 
+ closeAccessIntoScope(pCandidate, parent); + parentAccesses.erase(parentIt); - const Stmt* ancestor = info->parent; - const Stmt* target = nullptr; - while (ancestor) { - if (const For* f = dynamic_cast(ancestor)) { - if (vf.vars().count(f->var()) != 0) { - break; - } - target = f->get_parent(); + // the childs access inserted into the parent scope. + closeAccessIntoScope(candidate, parent); + continue; } - ancestor = ancestor->get_parent(); + // merge totally overlapping accesses. + parentIt->second->merge(candidate); + closeAccessIntoScope(parentIt->second, parent); + parentAccesses.erase(parentIt); + continue; } - if (info->parent != target) { - if (const Block* new_parent = dynamic_cast(target)) { - info->parent = new_parent; + // we didn't find a perfect match, but we need to check all open accesses of + // this buf for partial overlap. + for (auto it = parentAccesses.begin(); it != parentAccesses.end();) { + std::shared_ptr pCandidate = it->second; + // Partial overlap of parent access: close parent access. + if (candidate->overlaps(pCandidate)) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; } + ++it; } - // Now that analysis is complete we must normalize the costs by the - // parent Block we plan to insert the scalar var into. - info->store_cost = IRSimplifier::simplify( - new Div(info->store_cost, costByBlock_[info->parent])); - - if (!info->loads.empty()) { - info->load_cost = IRSimplifier::simplify( - new Div(info->load_cost, costByBlock_[info->parent])); - } - - access_by_buf[info->buf].push_back(info); + // Insert the childs closed access into the parent scope. + closeAccessIntoScope(candidate, parent); } - // For each buffer, for each access, determine if another access to the - // buffer could possibly write to the same region. - for (const auto& pair : access_by_buf) { + // copy across current open accesses, merging as necessary. + // for each Buf with an open access: + for (auto& pair : currentScope_->openAccesses()) { const Buf* buf = pair.first; - const std::vector>& accesses = pair.second; - for (const auto& info : accesses) { - // Filter out low cost accesses. - if (info->store_cost->isConstant() && - immediateAs(info->store_cost) <= 1 && - info->load_cost->isConstant() && - immediateAs(info->load_cost) <= 1) { - info->invalid = true; - continue; - } + if (pair.second.empty()) { + continue; + } - // TODO: this is n^2 by the number of accesses to a single buffer - // program wide, may be an issue in large programs. - for (const auto& i2 : accesses) { - if (info == i2) { + auto& parentAccesses = parent->getAccessMapByBuf(buf); + + // for each open access in the child scope for this Buf: + for (auto& hpair : pair.second) { + bool handled{false}; + std::shared_ptr candidate = hpair.second; + + for (auto it = parentAccesses.begin(); it != parentAccesses.end();) { + std::shared_ptr pCandidate = it->second; + + // If it completely overlaps then merge. + if (candidate->hash() == pCandidate->hash()) { + // if both accesses are found in conditional blocks, they cannot be + // merged, but the earlier must be closed. + if (pCandidate->conditionId() != parent->conditionId() && + pCandidate->conditionId() != candidate->conditionId()) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; + } + pCandidate->merge(candidate); + handled = true; + ++it; continue; } - // All accesses to a buf must have the same dimensionality. 
- assert(info->indices.size() == i2->indices.size()); - - // They don't overlap if there is a guaranteed difference in any - // dimension. - bool overlap = true; - for (size_t i = 0; i < info->indices.size(); ++i) { - const Expr* diff = new Sub(info->indices[i], i2->indices[i]); - diff = IRSimplifier::simplify(diff); - if (diff->isConstant() && !immediateEquals(diff, 0)) { - overlap = false; - break; - } + // It can overlap an access in the parent: close the parent access. + // The child access may still be open. + if (candidate->overlaps(pCandidate)) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; } - if (overlap) { - info->invalid = true; + ++it; + } + + // If this access depends on a locally scoped variable, it cannot be + // lifted out of the loop. + for (auto* v : currentScope_->localVars()) { + if (candidate->dependsOnVar(v)) { + closeAccessIntoScope(candidate, parent); + handled = true; break; } } - } - } - // Return valid access candidates in the order they were first seen. - for (const auto& info : encounterOrder_) { - if (!info->invalid) { - ret.push_back(info); + if (!handled) { + // If the inner scope was not conditional, but the outer scope is: all + // current accesses are now conditional in the parent scope. + if (candidate->conditionId() == 0) { + candidate->setConditionId(parent->conditionId()); + } + parentAccesses[candidate->hash()] = candidate; + } } } - return ret; + currentScope_ = parent; } -const Expr* RegisterizerReplacer::mutate(const Load* v) { - if (v->buf() != info_->buf) { - return IRMutator::mutate(v); - } +std::vector> RegisterizerAnalysis::getCandidates() { + currentScope_->filterClosed(); + std::sort( + currentScope_->closedAccesses().begin(), + currentScope_->closedAccesses().end(), + [](auto i1, auto i2) { return i1->accessOrder() < i2->accessOrder(); }); + return currentScope_->closedAccesses(); +} - initializerReady_ = false; +// RegisterizerReplacer - // sanity check indices for the same buf must have the same dimensionality. - assert(v->indices().size() == info_->indices.size()); - for (size_t i = 0; i < info_->indices.size(); ++i) { - if (!exprEquals(v->indices()[i], info_->indices[i])) { - return IRMutator::mutate(v); - } +const Expr* RegisterizerReplacer::mutate(const Load* v) { + auto it = loadToAccess_.find(v); + if (it == loadToAccess_.end()) { + // This access cannot be registerized. + return v; } - return var_; + auto& info = it->second; + + return info->replacement().var; } Stmt* RegisterizerReplacer::mutate(const Store* v) { - if (v->buf() != info_->buf) { - return IRMutator::mutate(v); + if (eliminatedIntializers_.count(v) != 0) { + // This store is the intializer for a scalar var that is already inserted. + return nullptr; } - if (initializerReady_ && info_->parent == v->get_parent()) { - initializer_ = v; - initializerReady_ = false; - // This is the easiest way to return an empty statement; - return new Block({}); + auto it = storeToAccess_.find(v); + if (it == storeToAccess_.end()) { + // This access cannot be registerized. + return IRMutator::mutate(v); } - initializerReady_ = false; + auto& info = it->second; - // sanity check indices for the same buf must have the same dimensionality. 
- assert(v->indices().size() == info_->indices.size()); - for (size_t i = 0; i < info_->indices.size(); ++i) { - if (!exprEquals(v->indices()[i], info_->indices[i])) { - return IRMutator::mutate(v); - } - } const Expr* new_val = v->value()->accept_mutator(this); - Store* s = new Store(var_wrapper_, {}, new_val, v->mask()); - return s; + return new Store(info->replacement().var_wrapper, {}, new_val, v->mask()); } -// Finds the Stmt in parent which contains stmt. -const Stmt* RegisterizerReplacer::findInsertionPoint( - const Stmt* stmt, - const Block* parent) { - while (stmt) { - if (stmt->get_parent() == parent) { - return stmt; +Stmt* RegisterizerReplacer::mutate(const Block* v) { + auto& scope = parentToAccesses_[v]; + + std::vector stmts; + for (Stmt* stmt : v->stmts()) { + { + // Insert the initializer for any Scalars scoped to this block. + auto it = scope.initializerPoints_.find(stmt); + if (it != scope.initializerPoints_.end()) { + for (auto& info : it->second) { + Stmt* initializer = + info->replacement().initializer->accept_mutator(this); + stmts.push_back(initializer); + } + scope.initializerPoints_.erase(it); + } } - stmt = stmt->get_parent(); - } - return nullptr; -} -Stmt* RegisterizerReplacer::mutate(const Block* v) { - // We need to mutate this block in place, rather than clone - since other - // AccessInfo objects may hold a pointer to it. - Block* v1 = const_cast(v); // NOLINT - assert(v1); - - Stmt* first_changed{nullptr}; - Stmt* last_changed{nullptr}; - std::list stmts = v1->stmts(); - for (Stmt* stmt : stmts) { - dirty_ = false; Stmt* stmt_new = stmt->accept_mutator(this); - if (dirty_) { - first_changed = first_changed ? first_changed : stmt_new; - last_changed = stmt_new; + if (stmt_new) { + if (stmt_new->get_parent()) { + stmt_new = Stmt::clone(stmt_new); + } + stmts.push_back(stmt_new); } - if (stmt_new == stmt) { - continue; + { + // Insert the finalizer for any Scalars scoped to this block. + auto it = scope.finalizePoints_.find(stmt); + if (it != scope.finalizePoints_.end()) { + for (auto& info : it->second) { + Store* finalizer = new Store( + info->buf(), + info->indices(), + info->replacement().var, + new IntImm(1)); + stmts.push_back(finalizer); + } + scope.finalizePoints_.erase(it); + } } - v1->replace_stmt(stmt, stmt_new); - first_changed = first_changed ? first_changed : stmt_new; - last_changed = stmt_new; } - dirty_ = first_changed != nullptr; + return new Block(stmts); +} - if (v != info_->parent) { - return v1; - } +void RegisterizerReplacer::buildReplacements() { + // Traverse the list of replacements, creating vars and updating our local + // maps. + for (auto& info : infoSet_) { + Var* v = new Var( + info->buf()->name_hint() + "_" + + c10::to_string(getBufferAccessCount(info->buf())), + info->buf()->dtype()); + + info->replacement().var = v; + + // we need to wrap the Var in a Buf so we can Load or Store it. + info->replacement().var_wrapper = new Buf(v, {}, info->buf()->dtype()); + + bool first = true; + for (auto* s : info->stores()) { + if (first && info->first_usage() == s && !info->firstUsageOverlapped()) { + info->replacement().initializer = new Let(v, s->value()); + eliminatedIntializers_.insert(s); + } else { + storeToAccess_[s] = info; + } - Stmt* let; - // If we didn't find an initial store: intialize with the original buffer. 
- if (!initializer_) { - let = new Let( - var_, - new Load( - info_->buf->dtype(), info_->buf, info_->indices, new IntImm(1))); - } else { - let = new Let(var_, initializer_->value()); - } - v1->insert_stmt_before(let, first_changed); + first = false; + } + + for (auto* s : info->loads()) { + loadToAccess_[s] = info; + } + + auto& scope = parentToAccesses_[info->block()]; + scope.initializerPoints_[info->first_usage()].push_back(info); + + // Only finalize if the scalar is written. + if (!info->stores().empty()) { + // push front to finalize in reverse order of encounter. + scope.finalizePoints_[info->last_usage()].push_front(info); + } - // If it was written to the buffer, make sure we write it out. - if (info_->stores.size() > 0) { - v1->insert_stmt_after( - new Store(info_->buf, info_->indices, var_, new IntImm(1)), - last_changed); + // create a default initializer by reading the access. + if (info->replacement().initializer == nullptr) { + info->replacement().initializer = new Let( + v, + new Load( + info->buf()->dtype(), + info->buf(), + info->indices(), + new IntImm(1))); + } } - return v1; } +} // namespace registerizer + // Apply scalar replacement to all accesses in s. Stmt* registerize(Stmt* s) { - RegisterizerAnalysis analysis; + s = IRSimplifier::simplify(s); + + // The outermost node must be a Block so we have somewhere to put outer scope + // scalars. + if (!dynamic_cast(s)) { + s = new Block({s}); + } + registerizer::RegisterizerAnalysis analysis; s->accept(&analysis); auto candidates = analysis.getCandidates(); - for (const auto& info : candidates) { - RegisterizerReplacer replacer(info); - s = s->accept_mutator(&replacer); - } + + registerizer::RegisterizerReplacer replacer(candidates); + s = s->accept_mutator(&replacer); return s; } diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h index 118686a3e4e1..551a9fbb3277 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.h +++ b/torch/csrc/jit/tensorexpr/registerizer.h @@ -11,6 +11,7 @@ namespace torch { namespace jit { namespace tensorexpr { +namespace registerizer { /* The Registerizer performs scalar replacement by looking for common Stores and Loads to a single item in a buffer and replacing them with a local temporary @@ -38,58 +39,289 @@ For example it can replace: This is particularly useful on GPUs when parallelizing, since after replacing loops with metavars we have a lot of accesses like this. */ -// Holds analysis information about accesses to a specific range of a -// buffer, including the number of loads and stores and the lowest common -// parent Block. -struct AccessInfo { +class Scope; + +/* Holds analysis information about accesses to a specific range of a + buffer, including the number of loads and stores and the lowest common parent + Block. + */ +class AccessInfo { + public: AccessInfo() = default; - AccessInfo(const Buf* b, const std::vector& i) - : buf(b), - indices(i), - store_cost(new IntImm(0)), - load_cost(new IntImm(0)) {} + AccessInfo( + SimplifierHashType h, + const Buf* b, + const std::vector& i, + size_t accessOrder) + : hash_(h), + buf_(b), + indices_(i), + store_cost_(new IntImm(0)), + load_cost_(new IntImm(0)), + accessOrder_(accessOrder) {} + + // Adds a Store to this access, which is in the provided scope. + void addStore(const Store* store, const std::shared_ptr& scope); + + // Adds a Load to this access, which occurs in the usage Stmt in the provided + // scope. 
+ void addLoad( + const Load* load, + const std::shared_ptr& scope, + const Stmt* usage); + + // Merge another AccessInfo into this one. + void merge(const std::shared_ptr& other); + + // Returns true if the other AccessInfo's bounds may overlap this one. + bool overlaps(const std::shared_ptr& other); - void addStore(const Store* s, const Block* p, const Expr* cost) { - store_cost = IRSimplifier::simplify(new Add(store_cost, cost)); - stores.push_back(s); - parent = parent ? Block::getSharedParent(parent, p) : p; - first_usage = first_usage ? first_usage : s; + // Returns true if the indices of this access depend on the provided Var. + bool dependsOnVar(const Var* v); + + // Clone this AccessInfo, and set this as the new accesses' hiddenAccess. + static std::shared_ptr cloneWithHiddenInfo( + const std::shared_ptr& orig); + + // print for debugging. + void print() const; + + SimplifierHashType hash() const { + return hash_; } - void addLoad( - const Load* l, - const Block* p, - const Expr* cost, - const Stmt* usage) { - load_cost = IRSimplifier::simplify(new Add(load_cost, cost)); - loads.push_back(l); - parent = parent ? Block::getSharedParent(parent, p) : p; - first_usage = first_usage ? first_usage : usage; + const Buf* buf() const { + return buf_; + } + + const std::vector& indices() const { + return indices_; + } + + const Block* block() const { + return block_; + } + + void setEnclosingBlock(const Block* b) { + block_ = b; + } + + const Stmt* first_usage() const { + return first_usage_; + } + const Stmt* last_usage() const { + return last_usage_; + } + + void setUsageMarks(const Stmt* first, const Stmt* last) { + first_usage_ = first; + last_usage_ = last; + } + + bool firstUsageOverlapped() const { + return firstUsageOverlapped_; + } + + const Expr* store_cost() const { + return store_cost_; + } + + const Expr* load_cost() const { + return load_cost_; + } + + const std::vector& stores() const { + return stores_; + } + + const std::vector& loads() const { + return loads_; + } + + void hoistCosts(const Expr* extent) { + store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent)); + load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent)); + } + + size_t conditionId() const { + return conditionId_; + } + + void setConditionId(size_t c) { + conditionId_ = c; + } + + size_t accessOrder() const { + return accessOrder_; + } + + std::shared_ptr hiddenAccess() const { + return hiddenAccess_; + } + + // Holds state relating to the scalar variable we will insert to replace some + // number of loads and stores. + struct ScalarReplacement { + Var* var{nullptr}; + Buf* var_wrapper{nullptr}; + Let* initializer{nullptr}; + }; + + ScalarReplacement& replacement() { + return replacement_; } - const Buf* buf; - std::vector indices; - const Block* parent{nullptr}; + private: + SimplifierHashType hash_; + const Buf* buf_; + std::vector indices_; + const Block* block_{nullptr}; + + const Stmt* first_usage_{nullptr}; + const Stmt* last_usage_{nullptr}; + + // Whether or not this access is overlapped in the first Stmt it appears. This + // means we cannot use it's first Store as the initializer. + bool firstUsageOverlapped_{false}; + + // The cost in real ops that this access represents, to enable + // filtering accesses that wont save any loads or stores. + const Expr* store_cost_; + const Expr* load_cost_; + + // The actual Stores and Loads which represent this access. + // Be careful with these, any mutator will invalidate these pointers. 
+ std::vector stores_; + std::vector loads_; + + // An identifier representing the conditional block, if any, this access + // depends on. + size_t conditionId_{0}; + + // An identifier representing the order this access was first encountered, for + // sorting returned results. + size_t accessOrder_{0}; + + // Sometimes when traversing the tree we need to record what would happen if + // we hoisted an access, but sometimes it doesn't work out. This lets us + // "undo" some mutation and return to the internal hidden AccessInfo. + // It will be removed after any further additions to this AccessInfo. + std::shared_ptr hiddenAccess_; + + ScalarReplacement replacement_; +}; - const Stmt* first_usage{nullptr}; +using AccessHashMap = + std::unordered_map>; + +// Represents a scope block and holds all accesses contained within it. +class Scope { + public: + Scope(const Block* b, std::shared_ptr parent, size_t conditionId = 0) + : block_(b), parent_(parent), conditionId_(conditionId) {} - const Expr* store_cost; - const Expr* load_cost; + AccessHashMap& getAccessMapByBuf(const Buf* b); - std::vector stores; - std::vector loads; + std::unordered_map& openAccesses() { + return openAccesses_; + } + + std::vector>& closedAccesses() { + return closedAccesses_; + } + + const Block* block() const { + return block_; + } + + std::shared_ptr parent() const { + return parent_; + } + + size_t conditionId() const { + return conditionId_; + } - bool invalid{false}; + const std::unordered_set& localVars() const { + return localVars_; + } + void addLocalVar(const Var* v) { + localVars_.insert(v); + } + + void closeAccess(const std::shared_ptr& info); + + void filterClosed(); + + private: + // Map of map to access, narrowing by Buf then by hash(Buf+Indices). + // This allows us to find a candidate access easily, and also check for + // overlap with other accesses to the same buf. Buf -> + // Hash -> + // Access + std::unordered_map openAccesses_; + std::vector> closedAccesses_; + + // The Block object this scope represents. + const Block* block_; + + // The enclosing scope object. + std::shared_ptr parent_; + + // An identifier representing the condition block this scope depends on. + size_t conditionId_; + + // A set of variables local to this scope (e.g. loop vars). + std::unordered_set localVars_; }; -// Walks the IR generating AccessInfo for each access. +/* Analyzes the graph and collects accesses to the same symbolic tensor element + * which can be replaced by a single local scalar. + * + * This works by recursively walking the tree in postfix order, building sets of + * accesses to the same symbolic element by scope and then merging lower scopes + * into their enclosing scope. + * + * It is safe to move two accesses of the same Tensor element to a local scalar + * Var if between all usages of the element there are no other Loads or Stores + * that may refer to it. In the comments I refer to this as overlapping the + * access, or "cutting" the existing AccessInfo. In the case where a candidate + * for registerization is cut, it may be possible to finalize the access early + * by writing it back to the Tensor and then create a new scalar variable after + * the overlapping access is complete. We will attempt to do this when it saves + * memory accesses. + * + * There are a few cases that make this more challenging: + * + * - For: Loops change the number of real usages of a buffer by the loop + * extent, but only if we can pull the definition and finalization of the scalar + * variable out of the loop block. 
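+ * When hoisting succeeds, the store and load costs are multiplied by the loop
+ * extent (see hoistCosts()).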
+ * + * - Cond: Conditions complicate lifting scalars out of internal scopes. + * Generally we cannot lift an access outside of a conditional scope unless + * there is already a reference to that same access at the higher scope, since + * we don't know if the condition was guarding an array access not safe at the + * higher scope. In the comments I refer to this as the condition "hiding" the + * access, and the outer access "unhiding" it. + * + * - IfThenElse: Same situation as Cond, except since IfThenElse is an Expr + * rather than a Stmt we cannot insert the scalar definition or finalizer + * within the conditional scope. Acccesses inside an IfThenElse can be safely + * combined with external accesses but cannot exist completely within. + * + * - Let: Accesses dependent on local variables via Let Stmts, or loop vars, + * cannot be raised outside of the scope of the dependent var. + */ class TORCH_API RegisterizerAnalysis : public IRVisitor { public: - RegisterizerAnalysis() : loopCost_(new IntImm(1)) {} + RegisterizerAnalysis() + : currentScope_(std::make_shared(nullptr, nullptr, 0)) {} virtual ~RegisterizerAnalysis() {} void visit(const For* v) override; + void visit(const Cond* v) override; + void visit(const Block* v) override; void visit(const Store* v) override; @@ -98,7 +330,7 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { void visit(const IfThenElse* v) override; - void visit(const Cond* v) override; + void visit(const Let* v) override; #define STMT_ON_STACK(Op) \ virtual void visit(const Op* v) override { \ @@ -110,54 +342,77 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { STMT_ON_STACK(AtomicAdd); STMT_ON_STACK(Allocate); STMT_ON_STACK(Free); - STMT_ON_STACK(Let); #undef STMT_ON_STACK std::vector> getCandidates(); private: - std::unordered_map> - candidates_; - std::unordered_map costByBlock_; - std::vector> encounterOrder_; + void mergeCurrentScopeIntoParent(); + void mergeHiddenScope(bool allowClosed); + void closeAccessIntoScope( + const std::shared_ptr& info, + const std::shared_ptr& scope); - const Expr* loopCost_; + std::unordered_set exprConditionals_; + // A stack of enclosing Stmts for tracking the usage Stmt of Loads. std::deque stmtStack_; - const Block* enclosingBlock_; + + // The current scope being analyzed. + std::shared_ptr currentScope_; + HashProvider hasher_; - size_t nested_conditions_{0}; + size_t conditionId_{0}; + size_t accessOrder_{0}; }; -// Walks the IR an replaces a single Acccess with a local scalar Var. +/* Replaces each registerizable access with a Scalar variable, including + * definition, initializer and finalizer. + */ class TORCH_API RegisterizerReplacer : public IRMutator { public: - RegisterizerReplacer(std::shared_ptr i) : info_(i) { - var_ = new Var(info_->buf->name_hint() + "_", info_->buf->dtype()); - var_wrapper_ = new Buf(var_, {}, info_->buf->dtype()); - - initializer_ = nullptr; + RegisterizerReplacer(std::vector>& vec) + : infoSet_(vec) { + buildReplacements(); } const Expr* mutate(const Load* v) override; Stmt* mutate(const Store* v) override; - // Finds the Stmt in parent which contains stmt. 
- const Stmt* findInsertionPoint(const Stmt* stmt, const Block* parent); - Stmt* mutate(const Block* v) override; private: - std::shared_ptr info_; - Var* var_; - Buf* var_wrapper_; - const Store* initializer_; - bool dirty_{false}; - bool initializerReady_{true}; + struct ReplacerScope { + std::unordered_map>> + initializerPoints_; + std::unordered_map>> + finalizePoints_; + }; + + // Creates the various ReplacerScope objects and builds internal maps. + void buildReplacements(); + + // State relating to the accesses yet to be replaced. + std::vector>& infoSet_; + std::unordered_map> storeToAccess_; + std::unordered_map> loadToAccess_; + std::unordered_map parentToAccesses_; + + // Holds the set of Stores that should be pulled into an initializer, so they + // can be eliminated. + std::set eliminatedIntializers_; + + // Tracks the number of times we've seen each buffer, so we can name the + // scalar Vars appropriately. + std::unordered_map bufferAccessCounts_; + unsigned int getBufferAccessCount(const Buf* b) { + return ++bufferAccessCounts_[b]; + } }; +} // namespace registerizer // Apply scalar replacement to all accesses in s. // To produce safe code, this must occur after handling parallelized axes and diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index 1d3712134335..7aec71f6b56d 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -244,6 +244,14 @@ class TORCH_API Block : public StmtNode { return nullptr; } + // returns the immediate child containing statement s. + const Stmt* getEnclosedRoot(const Stmt* s) const { + while (s && s->get_parent() != this) { + s = s->get_parent(); + } + return s; + } + private: std::list stmts_; }; From 154347d82f7af3cad03065c9f13c60758c8ff8e1 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 7 Oct 2020 19:57:37 -0700 Subject: [PATCH 48/69] Fix distributed documentation for asynchronous collective Work objects (#45709) Summary: Closes https://github.com/pytorch/pytorch/issues/42247. Clarifies some documentation related to `Work` object semantics (outputs of async collective functions). Clarifies the difference between CPU operations and CUDA operations (on Gloo or NCCL backend), and provides an example where the difference in CUDA operation's wait() semantics is necessary to understand for correct code. ![sync](https://user-images.githubusercontent.com/8039770/94875710-6f64e780-040a-11eb-8fb5-e94fd53534e5.png) Pull Request resolved: https://github.com/pytorch/pytorch/pull/45709 Reviewed By: ngimel Differential Revision: D24171256 Pulled By: rohan-varma fbshipit-source-id: 6365a569ef477b59eb2ac0a8a9a1c1f34eb60e22 --- docs/source/distributed.rst | 58 ++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index c83b5a1d34de..8117a3a63668 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -330,27 +330,59 @@ as they should never be created manually, but they are guaranteed to support two Synchronous and asynchronous collective operations -------------------------------------------------- -Every collective operation function supports the following two kinds of operations: - -synchronous operation - the default mode, when ``async_op`` is set to False. 
-when the function returns, it is guaranteed that -the collective operation is performed (not necessarily completed if it's a CUDA op since all -CUDA ops are asynchronous), and any further function calls depending on the data of the -collective operation can be called. In the synchronous mode, the collective function does not -return anything - -asynchronous operation - when ``async_op`` is set to True. The collective operation function +Every collective operation function supports the following two kinds of operations, +depending on the setting of the ``async_op`` flag passed into the collective: + +**Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. +When the function returns, it is guaranteed that +the collective operation is performed. In the case of CUDA operations, it is not guaranteed +that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any +further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, +function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of +synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream +synchronization, see `CUDA Semantics `__. +See the below script to see examples of differences in these semantics for CPU and CUDA operations. + +**Asynchronous operation** - when ``async_op`` is set to True. The collective operation function returns a distributed request object. In general, you don't need to create it manually and it is guaranteed to support two methods: -* ``is_completed()`` - returns True if the operation has finished -* ``wait()`` - will block the process until the operation is finished. +* ``is_completed()`` - in the case of CPU collectives, returns ``True`` if completed. In the case of CUDA operations, + returns ``True`` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the + default stream without further synchronization. +* ``wait()`` - in the case of CPU collectives, will block the process until the operation is completed. In the case + of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the + output can be utilized on the default stream without further synchronization. + +**Example** + +The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. +It shows the explicit need to synchronize when using collective outputs on different CUDA streams: + +:: + + # Code runs on each rank. + dist.init_process_group("nccl", rank=rank, world_size=2) + output = torch.tensor([rank]).cuda(rank) + s = torch.cuda.Stream() + handle = dist.all_reduce(output, async_op=True) + # Wait ensures the operation is enqueued, but not necessarily complete. + handle.wait() + # Using result on non-default stream. + with torch.cuda.stream(s): + s.wait_stream(torch.cuda.default_stream()) + output.add_(100) + if rank == 0: + # if the explicit call to wait_stream was omitted, the output below will be + # non-deterministically 1 or 101, depending on whether the allreduce overwrote + # the value after the add completed. + print(output) Collective functions -------------------- -.. autofunction:: broadcast +.. autofunction:: broadcast .. 
autofunction:: broadcast_object_list From 903acc6b83e058cb6ed1cb7faa5938b425e695fd Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 7 Oct 2020 20:35:36 -0700 Subject: [PATCH 49/69] CUDA BFloat16 support of clamp, remainder, lshift, rshift (#45247) Summary: Add CUDA BFloat16 support of clamp, remainder, lshift, rshift Pull Request resolved: https://github.com/pytorch/pytorch/pull/45247 Reviewed By: dzhulgakov Differential Revision: D24174258 Pulled By: ngimel fbshipit-source-id: bfcd2d1b3746bb0527d590533f3c38b9c4d0a638 --- .../ATen/native/cuda/BinaryRemainderKernel.cu | 2 +- .../ATen/native/cuda/BinaryShiftOpsKernels.cu | 10 ++++++---- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 6 +++--- test/test_torch.py | 16 +++++++++------- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu index 86b2703797dc..519c3588b02c 100644 --- a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu @@ -21,7 +21,7 @@ void remainder_kernel_cuda(TensorIterator& iter) { }); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "remainder_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { auto mod = ::fmod(a, b); diff --git a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu index 4b6533e62db6..67ff7954294d 100644 --- a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu @@ -13,8 +13,9 @@ namespace at { namespace native { void lshift_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "lshift_cuda", [&]() { + iter.dtype() == ScalarType::Half || + iter.dtype() == ScalarType::BFloat16) { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "lshift_cuda", [&]() { gpu_kernel_with_scalars( iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { @@ -34,8 +35,9 @@ void lshift_kernel_cuda(TensorIterator& iter) { void rshift_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "rshift_cuda", [&]() { + iter.dtype() == ScalarType::Half || + iter.dtype() == ScalarType::BFloat16) { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "rshift_cuda", [&]() { gpu_kernel_with_scalars( iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 6f5c9221dee6..c3c8dd1e5094 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -155,7 +155,7 @@ void erfinv_kernel_cuda(TensorIterator& iter) { } void clamp_kernel_cuda(TensorIterator& iter, Scalar min_value, Scalar max_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() { auto lower = min_value.to(); auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t 
v) -> scalar_t { @@ -170,7 +170,7 @@ void clamp_kernel_cuda(TensorIterator& iter, Scalar min_value, Scalar max_value) } void clamp_min_kernel_cuda(TensorIterator& iter, Scalar min_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_min_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_min_cuda", [&]() { auto lower = min_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { // Propagate nan, which doesn't propagate automatically for ROCm @@ -184,7 +184,7 @@ void clamp_min_kernel_cuda(TensorIterator& iter, Scalar min_value) { } void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_max_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_max_cuda", [&]() { auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { // Propagate nan, which doesn't propagate automatically for ROCm diff --git a/test/test_torch.py b/test/test_torch.py index 3ff5a1d73822..fd89dff7cb92 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -19947,7 +19947,7 @@ def test_movedim_view(self, device): _float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types _signed_types = [ - torch.half, torch.float, torch.double, + torch.half, torch.bfloat16, torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long ] @@ -20189,8 +20189,10 @@ def inner(self, device, dtype): ('chunk', 'neg_dim', _medium_2d, lambda t, d: [4, -2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('clamp', 'neg', _medium_2d, lambda t, d: [-1, 5], 1e-5, 1e-2, 1e-5, _signed_types, [torch.bfloat16]), ('clamp', 'pos', _medium_2d, lambda t, d: [1, 5], 1e-5, 1e-2, 1e-5, _unsigned_types, [torch.bfloat16]), - ('clamp_min', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types, [torch.bfloat16]), - ('clamp_max', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types, [torch.bfloat16]), + ('clamp_min', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=True), [torch.bfloat16]), + ('clamp_max', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=True), [torch.bfloat16]), ('clone', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('conj', '', _small_3d, lambda t, d: [], 1e-5, 0, 1e-5, _types_no_half, [torch.bfloat16], False), @@ -20275,14 +20277,14 @@ def inner(self, device, dtype): 1e-5, 1e-5, 1e-5, _float_types_no_half), ('mvlgamma', '2d_p=2', lambda t, d: _small_2d(t, d).clamp(0.6, 10), lambda t, d: [2], 1e-5, 1e-5, 1e-5, _float_types_no_half), - ('remainder', 'value', _small_3d, lambda t, d: [3], 1e-1, 1e-5, 1e-5, _signed_types), - ('remainder', 'negative_value', _small_3d, lambda t, d: [-3], 1e-1, 1e-5, 1e-5, _signed_types), + ('remainder', 'value', _small_3d, lambda t, d: [3], 1e-1, 1e-2, 1e-5, _signed_types), + ('remainder', 'negative_value', _small_3d, lambda t, d: [-3], 1e-1, 1e-2, 1e-5, _signed_types), ('remainder', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], - 1e-1, 1e-5, 1e-5, _signed_types), + 1e-1, 1e-2, 1e-5, _signed_types), ('remainder', 'negative_tensor', _small_3d, lambda t, d: [0 - _small_3d(t, d, has_zeros=False)], - 1e-1, 1e-5, 1e-5, _signed_types), + 1e-1, 1e-2, 1e-5, 
_signed_types), ('std', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), ('std', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), From c59c4b0d778b445cd016c77dc42ee024ef150bf7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 7 Oct 2020 20:38:19 -0700 Subject: [PATCH 50/69] Fix cholesky TF32 tests (#45492) Summary: This test is changed one day before the landing of the tf32 tests PR, therefore the fix for this is not included in that PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45492 Reviewed By: ezyang Differential Revision: D24101876 Pulled By: ngimel fbshipit-source-id: cb3615b2fb8acf17abe54cd18b1faec26582d6b6 --- test/test_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_torch.py b/test/test_torch.py index fd89dff7cb92..312943d8715c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7794,6 +7794,7 @@ def cholesky_test_helper(n, batch_dims, upper): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @tf32_on_and_off(0.01) def test_cholesky(self, device, dtype): from torch.testing._internal.common_utils import \ (random_symmetric_pd_matrix, From 81d40aaf96e84adde9f4fa2fe26761f1a54bc0b6 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 7 Oct 2020 20:43:54 -0700 Subject: [PATCH 51/69] Add `[zc]heevd` to the list of MKL symbols exported from torch_cpu (#46002) Summary: cpu implementation of `torch.symeig` uses `[zc]heev`, but MAGMA only have `d`-suffixed flavors of those functions Fixes https://github.com/pytorch/pytorch/issues/45922 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46002 Reviewed By: walterddr Differential Revision: D24177730 Pulled By: malfet fbshipit-source-id: 0e9aeb60a83f8a4b8ac2a86288721bd362b6040b --- caffe2/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index fe5240118b2f..8e358c9503f7 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1026,7 +1026,11 @@ if($ENV{TH_BINARY_BUILD}) # # These linker commands do not work on OS X, do not attempt this there. 
# (It shouldn't matter anyway, though, because OS X has dropped CUDA support) - set_target_properties(torch_cpu PROPERTIES LINK_FLAGS "-Wl,--undefined=mkl_lapack_slaed0 -Wl,--undefined=mkl_lapack_dlaed0 -Wl,--undefined=mkl_lapack_dormql -Wl,--undefined=mkl_lapack_sormql") + foreach(_symb slaed0 daled0 dormql sormql zheevd cheevd) + STRING(APPEND _undefined_link_flags " -Wl,--undefined=mkl_lapack_${_symb}") + endforeach(_symb) + set_target_properties(torch_cpu PROPERTIES LINK_FLAGS ${_undefined_link_flags}) + endif() endif() From 00b8ebe60c5dfa3c14a76e71e0166d0bb8cda4e3 Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 7 Oct 2020 21:32:51 -0700 Subject: [PATCH 52/69] [FX] Preserve type annotations on generated code in Graph (#45880) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45880 Test Plan: Imported from OSS Reviewed By: dzhulgakov Differential Revision: D24127303 Pulled By: jamesr66a fbshipit-source-id: 3a042bcfb0bf9f58ac318cc814dfc3cca683c7f8 --- test/test_fx.py | 32 ++++++++++-- torch/fx/graph.py | 102 ++++++++++++++++++++++++++++--------- torch/fx/node.py | 14 ++++- torch/fx/proxy.py | 10 ++-- torch/fx/symbolic_trace.py | 16 ++++-- 5 files changed, 137 insertions(+), 37 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 76217fce9e80..94bea7032ab8 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -15,7 +15,7 @@ from fx.quantization import Quantizer -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS from torch.testing._internal.jit_utils import JitTestCase @@ -33,6 +33,10 @@ def forward(self, x): def a_non_torch_leaf(a, b): return a + b +class Pair(NamedTuple): + x : torch.Tensor + y : torch.Tensor + class TestFX(JitTestCase): def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -131,7 +135,8 @@ def test_disallow_override(self): # Custom delegate to disallow in-place tensor operations class NoMutableCallTracer(Tracer): def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node: + args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: name = target if isinstance(target, str) else torch.typename(target) if name[-1] == '_': raise RuntimeError('In-place operations are not supported') @@ -448,7 +453,8 @@ def forward(self, a): def test_node_tagging(self): class TaggingTracer(Tracer): def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node: + args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: n = super().create_node(kind, target, args, kwargs, name) n.tag = 'foo' return n @@ -765,6 +771,26 @@ def forward(self, x): # Test shape propogation and make sure results match actual self.assertEqual(output_shape, ref_out.shape) + def test_fn_type_annotations(self): + class Foo(torch.nn.Module): + def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: + return {'a': p.x + p.y + z + i} + + foo_scripted = torch.jit.script(Foo()) + foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) + + fxed = symbolic_trace(Foo()) + 
fxed_scripted = torch.jit.script(fxed) + fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) + + def test_typename_print(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), + type_expr=List[float]) + output : torch.fx.Node = graph.output(b) + self.assertTrue('typing.List[float]' in str(graph)) + def test_find_single_partition(self): class testModule(torch.nn.Module): def forward(self, a, b): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 600fcb27a850..9994bd8be65a 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -3,6 +3,7 @@ from typing import Callable, Any, List, Dict, Optional, Tuple, Set import builtins import torch +import types import keyword import re @@ -52,6 +53,29 @@ def _format_target(base: str, target: str) -> str: r = f'{r}.{e}' return r +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj): + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + # HACK: In Python 3.6, type aliases from `typing` are instances of `type`, but in + # later Python versions, type aliases are not instances of `type`!! We want + # all type aliases to fall through to `repr`, so if we have a type that is + # in the module typing, don't go down this path. + if isinstance(obj, type) and obj.__module__ != 'typing': + if obj.__module__ == 'builtins': + return obj.__qualname__ + return f'{obj.__module__}.{obj.__qualname__}' + if obj is ...: + return('...') + if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + class insert_before: def __init__(self, n : Node): self.n = n @@ -65,6 +89,9 @@ def __exit__(self, type, value, tb): class Graph: def __init__(self): + """ + Construct an empty Graph. 
+ """ self._nodes : List[Node] = [] self._used_names : Dict[str, int] = {} # base name -> number self._insert_point : Optional[Node] = None @@ -90,12 +117,13 @@ def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argume def create_node(self, op: str, target: Target, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, - name: Optional[str] = None) -> Node: + name: Optional[str] = None, + type_expr: Optional[Any] = None) -> Node: assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') args = () if args is None else args kwargs = {} if kwargs is None else kwargs sanitized_name = self._register_name_used(name) if name is not None else self._name(target) - n = Node(self, sanitized_name, op, target, args, kwargs) + n = Node(self, sanitized_name, op, target, args, kwargs, type_expr) if self._insert_point is not None: before_idx = self._nodes.index(self._insert_point) self._nodes.insert(before_idx, n) @@ -130,29 +158,32 @@ def erase_node(self, to_erase : Node): self._nodes.pop(idx) # sugar for above when you know the op - def placeholder(self, name: str) -> Node: - return self.create_node('placeholder', name) + def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: + return self.create_node('placeholder', name, type_expr=type_expr) - def get_attr(self, name: str) -> Node: - return self.create_node('get_attr', name) + def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node: + return self.create_node('get_attr', name, type_expr=type_expr) def call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_module', module_name, args, kwargs) + kwargs: Optional[Dict[str, Argument]] = None, + type_expr: Optional[Any] = None) -> Node: + return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) def call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_method', method_name, args, kwargs) + kwargs: Optional[Dict[str, Argument]] = None, + type_expr: Optional[Any] = None) -> Node: + return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) def call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_function', the_function, args, kwargs) + kwargs: Optional[Dict[str, Argument]] = None, + type_expr: Optional[Any] = None) -> Node: + return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node: """ copy a node from one graph into another. 
arg_transform needs to transform arguments from the graph of node @@ -181,10 +212,10 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb except ValueError: pass name = self._name(sanitized_name) - return self.create_node(node.op, node.target, args, kwargs, name) + return self.create_node(node.op, node.target, args, kwargs, name, node.type) - def output(self, result: Argument): - return self.create_node(op='output', target='output', args=(result,)) + def output(self, result: Argument, type_expr: Optional[Any] = None): + return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) def _name(self, target: Target) -> str: if callable(target): @@ -224,10 +255,23 @@ def python_code(self, root_module: str) -> str: free_vars: List[str] = [] modules_used : Set[str] = set() body: List[str] = [] + maybe_return_annotation : str = '' + + def register_modules_used(qualified_name : str): + if '.' in qualified_name: + module_name = qualified_name.split('.', maxsplit=1)[0] + modules_used.add(module_name) + + def type_repr(o : Any): + typename = _type_repr(o) + register_modules_used(typename) + return typename + for node in self._nodes: if node.op == 'placeholder': assert isinstance(node.target, str) - free_vars.append(node.target) + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + free_vars.append(f'{node.target}{maybe_type_annotation}') raw_name = node.target.replace('*', '') if raw_name != node.name: body.append(f'{node.name} = {raw_name}\n') @@ -246,9 +290,7 @@ def python_code(self, root_module: str) -> str: body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n') continue qualified_name = _qualified_name(node.target) - if '.' 
in qualified_name: - module_name = qualified_name.split('.', maxsplit=1)[0] - modules_used.add(module_name) + register_modules_used(qualified_name) if qualified_name == 'getattr' and \ isinstance(node.args, tuple) and \ isinstance(node.args[1], str) and \ @@ -267,6 +309,8 @@ def python_code(self, root_module: str) -> str: body.append(f'{node.name} = {_format_target(root_module, node.target)}\n') continue elif node.op == 'output': + if node.type is not None: + maybe_return_annotation = f" -> {type_repr(node.type)}" body.append(f'return {node.args[0]}') continue raise NotImplementedError(f'node: {node.op} {node.target}') @@ -277,13 +321,17 @@ def python_code(self, root_module: str) -> str: code = '\n'.join(' ' + line for line in code.split('\n')) + '\n' fn_code = f"""\ {import_block} -def forward(self, {', '.join(free_vars)}): +def forward(self, {', '.join(free_vars)}){maybe_return_annotation}: {code} """ + return fn_code def __str__(self) -> str: placeholder_names : List[str] = [] + # This is a one-element array just so `format_node` can modify the closed + # over value + maybe_return_typename : List[str] = [''] def format_arg(arg) -> str: if isinstance(arg, list): @@ -305,20 +353,26 @@ def format_arg(arg) -> str: def format_node(n : Node) -> Optional[str]: if n.op == 'placeholder': assert isinstance(n.target, str) - placeholder_names.append(n.target) + arg_str = n.target + arg_str += arg_str + f': {_type_repr(n.type)}' if n.type is not None else '' + placeholder_names.append(arg_str) return None elif n.op == 'get_attr': - return f'%{n.name} : [#users={len(n.users)}] = self.{n.target}' + maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else '' + return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = self.{n.target}' elif n.op == 'output': + if n.type is not None: + maybe_return_typename[0] = f' -> {_type_repr(n.type)}' return f'return {n.args[0]}' else: - return f'%{n.name} : [#users={len(n.users)}] = {n.op}[target={n.target}](' \ + maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else '' + return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = {n.op}[target={n.target}](' \ f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})' node_strs = [format_node(node) for node in self._nodes] param_str = ', '.join(placeholder_names) - s = f'graph({param_str}):' + s = f'graph({param_str}){maybe_return_typename[0]}:' for node_str in node_strs: if node_str: s += '\n ' + node_str diff --git a/torch/fx/node.py b/torch/fx/node.py index 458e1d3c66a8..7d35483fc5d8 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -21,7 +21,8 @@ class Node: def __init__(self, graph: 'Graph', name: str, op: str, target: Target, - args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None: + args: Tuple[Argument, ...], kwargs: Dict[str, Argument], + type : Optional[Any] = None) -> None: self.graph = graph self.name = name # unique name of value being created assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output'] @@ -39,6 +40,17 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, # # Is a dict to act as an "ordered set". Keys are significant, value dont-care self.users : Dict['Node', None] = {} + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. 
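+ # (Illustrative example mirroring the tests in this diff: a placeholder
+ # created with type_expr=List[float] is rendered as
+ # `x : typing.List[float]` in the generated forward() signature.)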
+ # For the return ndoe, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. `return` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the `return` node. + self.type : Optional[Any] = type @property def args(self) -> Tuple[Argument, ...]: diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 90593c4b82f4..20d71781ce1e 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -11,7 +11,8 @@ class TracerBase: graph: Graph def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None) -> Node: + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: """ Inserts a graph node given target, args, kwargs, and name. @@ -19,7 +20,7 @@ def create_node(self, kind : str, target : Union[str, Callable], modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ - return self.graph.create_node(kind, target, args, kwargs, name) + return self.graph.create_node(kind, target, args, kwargs, name, type_expr) def create_arg(self, a: Any) -> Argument: """ @@ -65,12 +66,13 @@ class TraceError(ValueError): # Unwrap the proxies inside args, and kwargs, create the resulting node # and then wrap the result in a proxy. -def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any], name=None): +def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any], + name=None, type_expr : Optional[Any] = None): args = tracer.create_arg(args_) kwargs = tracer.create_arg(kwargs_) assert isinstance(args, tuple) assert isinstance(kwargs, dict) - rn = tracer.create_node(op, target, args, kwargs, name) + rn = tracer.create_node(op, target, args, kwargs, name, type_expr) return Proxy(rn, tracer) class Proxy: diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 7c295f6af133..7c5b2d734ea0 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -121,18 +121,23 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo def trace(self, root: torch.nn.Module) -> Graph: self.root = root + fn = type(root).forward self.graph = Graph() - fn = type(root).forward assert isinstance(fn, FunctionType) co = fn.__code__ total_args = co.co_argcount + co.co_kwonlyargcount names_iter = iter(co.co_varnames) next(names_iter) # skip self args : List[Any] = [root] - args.extend(self._proxy_placeholder(next(names_iter)) for name in range(1, total_args)) + + def make_proxy_placeholder(): + name = next(names_iter) + return self._proxy_placeholder(name, fn.__annotations__.get(name, None)) + args.extend(make_proxy_placeholder() for _ in range(1, total_args)) if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs if co.co_flags & inspect.CO_VARARGS: args.append(self._proxy_placeholder('*' + next(names_iter))) if co.co_flags & inspect.CO_VARKEYWORDS: @@ -149,13 +154,14 @@ def module_call_wrapper(mod, *args, **kwargs): return _create_proxy(self, 'call_module', module_qualified_name, args, kwargs) try: torch.nn.Module.__call__ = module_call_wrapper - self.create_node('output', 'output', (self.create_arg(fn(*args)),), {}) + self.create_node('output', 'output', (self.create_arg(fn(*args)),), 
{}, + type_expr=fn.__annotations__.get('return', None)) finally: torch.nn.Module.__call__ = orig_call return self.graph - def _proxy_placeholder(self, name: str) -> Proxy: - return Proxy(self.create_node('placeholder', name, (), {}), self) + def _proxy_placeholder(self, name: str, type_expr: Optional[Any] = None) -> Proxy: + return Proxy(self.create_node('placeholder', name, (), {}, type_expr=type_expr), self) # Symbolic tracing API # From 8d14b50e943c76f3ba32e6d221e069db6909c426 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 7 Oct 2020 22:25:38 -0700 Subject: [PATCH 53/69] codegen: Improve array default handing (#45163) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45163 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D24132279 Pulled By: mruberry fbshipit-source-id: 77069e7526b35cf8d13ba448e313c90f20cc67cf --- tools/codegen/api/cpp.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index 566d8f8265a9..f8fd2fdbde55 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -152,7 +152,6 @@ def returns_type(rs: Sequence[Return]) -> str: 'None': 'c10::nullopt', # UGH this one is type directed 'Mean': 'at::Reduction::Mean', '[]': '{}', - '[0,1]': '{0,1}', # TODO: stop special casing 'contiguous_format': 'MemoryFormat::Contiguous', 'long': 'at::kLong', } @@ -181,6 +180,20 @@ def default_expr(d: str, t: Type) -> str: i += 2 return f'"{s}"' + + if isinstance(t, OptionalType): + if d == 'None': + return 'c10::nullopt' + + return default_expr(d, t.elem) + + if isinstance(t, ListType): + if (d.startswith('[') and d.endswith(']')): + return '{' + d[1:-1] + '}' + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + return JIT_TO_CPP_DEFAULT.get(d, d) # Convert an argument into its C++ API form From b2bff9e4310d184c92e66617c0e396c119d594d1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 7 Oct 2020 22:36:27 -0700 Subject: [PATCH 54/69] Workaround for cublas bug for 45724 (#46001) Summary: Fixes https://github.com/pytorch/pytorch/issues/45724 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46001 Reviewed By: mruberry Differential Revision: D24184058 Pulled By: ngimel fbshipit-source-id: 7d2bab3206ddbc10a7cae3efd9b5e253f38400a9 --- aten/src/THC/THCBlas.cu | 54 +++++++++++++++++++++++++++++++++++++++-- test/test_torch.py | 10 ++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu index 859d904a582b..3f16eec6df60 100644 --- a/aten/src/THC/THCBlas.cu +++ b/aten/src/THC/THCBlas.cu @@ -133,6 +133,56 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6 at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } +#ifndef __HIP_PLATFORM_HCC__ +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200 +#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx +#else +// Workaround for https://github.com/pytorch/pytorch/issues/45724 +cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType Atype, + int lda, + long long int strideA, + const void *B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void *beta, + void *C, + cudaDataType Ctype, + int ldc, + long long int 
strideC, + int64_t batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo) +{ + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major != 7) { + return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo); + } + cublasStatus_t result; + constexpr int64_t split = 63 * 1024; + for(int64_t i = 0; i < batchCount; i += split) { + int64_t count = std::min(split, batchCount - i); + result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, + (char *)A + i * strideA * 2, Atype, lda, strideA, + (char *)B + i * strideB * 2, Btype, ldb, strideB, + beta, + (char *)C + i * strideC * 2, Ctype, ldc, strideC, + (int)count, computeType, algo); + THCublasCheck(result); + } + return result; +} +#endif +#endif + void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB, at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount) @@ -167,7 +217,7 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); #endif // CUDA_VERSION < 11000 - THCublasCheck(cublasGemmStridedBatchedEx(handle, + THCublasCheck(cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, b, CUDA_R_16F, (int)ldb, strideB, @@ -207,7 +257,7 @@ void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, i if (prop->major < 8) { TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU"); } - THCublasCheck(cublasGemmStridedBatchedEx(handle, + THCublasCheck(cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA, b, CUDA_R_16BF, (int)ldb, strideB, diff --git a/test/test_torch.py b/test/test_torch.py index 312943d8715c..39f2f925bd11 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -16814,6 +16814,16 @@ def test_addmm_sizes(self, device, dtype): m2 = torch.randn(k, m, device=device).to(dtype) self._test_addmm_addmv(torch.addmm, M, m1, m2) + @onlyCUDA + def test_matmul_45724(self, device): + # https://github.com/pytorch/pytorch/issues/45724 + a = torch.rand(65537, 22, 64).cuda().half() + b = torch.rand(65537, 64, 22).cuda().half() + c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device='cuda') + cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half() + torch.matmul(a, b, out=c) + self.assertEqual(c, cpu_result) + def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn): def compare_with_numpy_bin_op(torch_fn, np_fn, x, y): y_np = y.cpu().numpy() From ef4817fe5a16ba9969562911c5363736a1003bb0 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 7 Oct 2020 23:12:41 -0700 Subject: [PATCH 55/69] Add `tensor_split` function, based on `numpy.array_split` (#45168) Summary: Fixes https://github.com/pytorch/pytorch/issues/9382 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45168 Reviewed By: ngimel Differential Revision: D24166164 Pulled By: mruberry fbshipit-source-id: 795459821e52885bc99623a01a2abec060995ce6 --- aten/src/ATen/BatchingRegistrations.cpp | 18 ++++ 
aten/src/ATen/core/NamedRegistrations.cpp | 2 + aten/src/ATen/core/aten_interned_strings.h | 1 + aten/src/ATen/native/TensorShape.cpp | 31 +++++++ aten/src/ATen/native/native_functions.yaml | 8 ++ docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + test/test_autograd.py | 2 +- test/test_jit.py | 6 ++ test/test_torch.py | 88 ++++++++++++++++++- test/test_vmap.py | 21 +++++ tools/autograd/gen_autograd.py | 1 + torch/_tensor_docs.py | 7 ++ torch/_torch_docs.py | 61 +++++++++++++ torch/overrides.py | 1 + .../_internal/common_methods_invocations.py | 4 + 16 files changed, 251 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index e930ffd7e2ea..e0f4d11bca54 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -172,6 +172,22 @@ std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int6 return result; } +std::vector tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical); + self_physical.makeLogicalFromPhysicalListInplace(result); + return result; +} + +std::vector tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical); + self_physical.makeLogicalFromPhysicalListInplace(result); + return result; +} + Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); // NB: unsqueeze has some special handling of its `dim` argument so we can't call @@ -527,6 +543,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { // view operations m.impl("chunk", chunk_batching_rule); + m.impl("tensor_split.sections", tensor_split_sections_batching_rule); + m.impl("tensor_split.indices", tensor_split_indices_batching_rule); m.impl("diagonal", diagonal_batching_rule); m.impl("expand", expand_batching_rule); m.impl("expand_as", native::expand_as); // composite wrt autograd diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 6712be56ebb2..640ad0c181e4 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -453,6 +453,8 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("tanh", CppFunction::makeFallthrough()); m.impl("tanh.out", CppFunction::makeFallthrough()); m.impl("tanh_", CppFunction::makeFallthrough()); + m.impl("tensor_split.indices", CppFunction::makeFallthrough()); + m.impl("tensor_split.sections", CppFunction::makeFallthrough()); m.impl("threshold", CppFunction::makeFallthrough()); m.impl("threshold.out", CppFunction::makeFallthrough()); m.impl("threshold_", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 54481814be5b..3c82ecdc48c0 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -664,6 +664,7 @@ _(aten, tan) \ _(aten, tanh) \ _(aten, tensor) \ _(aten, tensordot) \ +_(aten, tensor_split) \ _(aten, th_addmm) \ _(aten, th_clone) \ _(aten, th_norm) \ diff --git 
a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 7fba7916354a..5aac8e9a1715 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -513,6 +513,37 @@ std::vector chunk(const Tensor& self, int64_t chunks, int64_t dim) { } } +std::vector tensor_split(const Tensor& self, int64_t sections, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor"); + int64_t dim_ = maybe_wrap_dim(dim, self.dim()); + TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections); + std::vector splits(sections); + int64_t min_split_size = self.size(dim_) / sections; + int64_t num_splits_one_extra = self.size(dim_) % sections; + int64_t start_idx = 0; + for (int64_t split_idx = 0; split_idx < sections; split_idx++) { + int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size; + splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size); + start_idx += split_size; + } + return splits; +} + +std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor"); + int64_t dim_ = maybe_wrap_dim(dim, self.dim()); + int64_t num_indices = indices.size(); + std::vector splits(num_indices + 1); + int64_t start_idx = 0; + for (int64_t split_idx = 0; split_idx < num_indices; split_idx++) { + int64_t end_idx = indices[split_idx]; + splits[split_idx] = at::slice(self, dim_, start_idx, end_idx); + start_idx = end_idx; + } + splits[num_indices] = at::slice(self, dim_, start_idx, self.size(dim_)); + return splits; +} + std::vector unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) { TORCH_CHECK(self.dim() > 0, "chunk expects at least a 1-dimensional tensor"); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e64a66a07417..8559eb8bdc60 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -912,6 +912,14 @@ variants: function, method device_guard: False +- func: tensor_split.sections(Tensor(a) self, int sections, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: full + variants: function, method + +- func: tensor_split.indices(Tensor(a) self, int[] indices, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: full + variants: function, method + - func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor use_c10_dispatcher: full variants: function, method diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 94b1fb25f58e..ef6d60599d7f 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -573,6 +573,7 @@ view of a storage and defines numeric operations on it. .. automethod:: symeig .. automethod:: t .. automethod:: t_ + .. automethod:: tensor_split .. automethod:: to .. automethod:: to_mkldnn .. 
automethod:: take diff --git a/docs/source/torch.rst b/docs/source/torch.rst index d0537947d4ff..bb267810cf41 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -99,6 +99,7 @@ Indexing, Slicing, Joining, Mutating Ops stack t take + tensor_split transpose unbind unsqueeze diff --git a/test/test_autograd.py b/test/test_autograd.py index 6bd6925e015f..3c0d0a9a2e8e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4833,7 +4833,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu', 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', - 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot'] + separate_complex_tests + 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests # TODO(@anjali411): add tests for 'sub', 'div # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411 diff --git a/test/test_jit.py b/test/test_jit.py index 797904d2bf20..01bf1339bcf7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15424,6 +15424,12 @@ def fn(x): 'test_split_size_list', 'test_split_size_list_dim', 'test_split_size_list_dim_neg0', + 'test_tensor_indices_sections', + 'test_tensor_indices_sections_dim', + 'test_tensor_indices_sections_dim_neg0', + 'test_tensor_split_sections', + 'test_tensor_split_sections_dim', + 'test_tensor_split_sections_dim_neg0', } EXCLUDE_PYTHON_PRINT = { diff --git a/test/test_torch.py b/test/test_torch.py index 39f2f925bd11..ef40a54d4eee 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -32,7 +32,7 @@ do_test_dtypes, IS_SANDCASTLE, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, torch_to_numpy_dtype_dict, skipIfNoSciPy, IS_MACOS, IS_PPC, - wrapDeterministicFlagAPITest) + wrapDeterministicFlagAPITest, make_tensor) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \ @@ -8346,6 +8346,92 @@ def test_contiguous(self, device): x.set_(x.storage(), 0, x.size(), stride) self.assertTrue(x.is_contiguous()) + @onlyOnCPUAndCUDA + # Skip BFloat16 since numpy does not support it + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + def test_tensor_split_sections(self, device, dtype): + input_sizes = [ + (0,), + (10,), + (10, 0), + (0, 10), + (4, 10), + (12, 3), + ] + for input_size in input_sizes: + a_base = make_tensor(input_size, device, dtype, low=-9, high=9) + # Run tests on transposed input if it has at least 2 dims + for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: + a_n = a.cpu().numpy() + for dim in range(-a.dim(), a.dim()): + for sections in range(1, 2 * a.size(dim)): + msg = f'input_size {input_size}, sections {sections}, dim {dim}' + result = torch.tensor_split(a, sections, dim) + for result_item in result: + self.assertEqual(result_item.device, torch.device(device), msg=msg) + self.assertEqual(result_item.dtype, dtype, msg=msg) + result_n = np.array_split(a_n, sections, dim) + self.assertEqual(result_n, result, msg=msg) + + @onlyOnCPUAndCUDA + # Skip BFloat16 since numpy does not support it + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + def test_tensor_split_indices(self, 
device, dtype): + input_sizes = [ + (0,), + (10,), + (10, 0), + (0, 10), + (4, 10), + (12, 3), + ] + indices_args = [ + (), + (0,), + (3,), + (10,), + (-1,), + (-10,), + (2, -1), + (3, 4, 10), + (0, -1, 0, 10), + (1, 5, 2, 8), + ] + for input_size in input_sizes: + a_base = make_tensor(input_size, device, dtype, low=-9, high=9) + # Run tests on transposed input if it has at least 2 dims + for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: + a_n = a.cpu().numpy() + for dim in range(-a.dim(), a.dim()): + for indices in indices_args: + result = torch.tensor_split(a, indices, dim) + msg = f'input_size {input_size}, indices {indices}, dim {dim}' + for result_item in result: + self.assertEqual(result_item.device, torch.device(device), msg=msg) + self.assertEqual(result_item.dtype, dtype, msg=msg) + result_n = np.array_split(a_n, indices, dim) + self.assertEqual(result_n, result, msg=msg) + + @onlyOnCPUAndCUDA + def test_tensor_split_errors(self, device): + S = 10 + test_cases = [ + # input size, sections or indices, dim, error type, error message, numpy error type + [(S,), 10, 1, IndexError, r'Dimension out of range', IndexError], + [(), 10, 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError], + [(S,), (10,), 1, IndexError, r'Dimension out of range', IndexError], + [(), (10,), 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError], + [(S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError], + [(S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError], + ] + for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases: + a = torch.randn(input_size, device=device) + msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}' + with self.assertRaisesRegex(err, err_msg, msg=msg): + torch.tensor_split(a, sections_or_indices, dim) + with self.assertRaises(numpy_err, msg=msg): + np.array_split(a.cpu().numpy(), sections_or_indices, dim) + def test_index(self, device): def consec(size, start=1): diff --git a/test/test_vmap.py b/test/test_vmap.py index abec2c0ae489..39775a816eff 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -1237,6 +1237,27 @@ def wrapped(*args, **kwargs): test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), check_propagates_grad=False) + def test_tensor_split(self): + test = self._vmap_view_test + op = torch.tensor_split + B0, B1, B2 = 7, 11, 13 + + # tests for torch.tensor_split(self, indices_or_sections: int, dim) + test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) + test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + def test_split(self): test = self._vmap_view_test op = torch.split diff --git a/tools/autograd/gen_autograd.py 
b/tools/autograd/gen_autograd.py index c12e9b2003d8..3e303c906b89 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -87,6 +87,7 @@ RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({ 'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as', 'expand_as', 'view_as', 'real', 'imag', 'narrow', 'movedim', + 'tensor_split' }) def format_return_type(returns): diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 7caceff4a1d1..ffa4ffe6100a 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -4045,6 +4045,13 @@ def callable(a, b) -> number See :func:`torch.unsafe_split` """) +add_docstr_all('tensor_split', + r""" +tensor_split(indices_or_sections, dim=0) -> List of Tensors + +See :func:`torch.tensor_split` +""") + add_docstr_all('stft', r""" stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4ad620d4abd7..609cd34b2e95 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1209,6 +1209,67 @@ def merge_dicts(*dicts): """.format(**common_args)) +add_docstr(torch.tensor_split, + r""" +tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + +Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, +along dimension :attr:`dim` according to the indices or number of sections specified +by :attr:`indices_or_sections. This function is based on NumPy's +:func:`numpy.array_split`. + +Args: + input (Tensor): the tensor to split + indices_or_sections (int or (list(int))): + If :attr:`indices_or_sections` is an integer ``n``, :attr:`input` is split + into ``n`` sections along dimension :attr:`dim`. If :attr:`input` is divisible + by ``n`` along dimension :attr:`dim`, each section will be of equal size, + :code:`input.size(dim) / n`. If :attr:`input` is not divisible by ``n``, the + sizes of the first :code:`int(input.size(dim) % n)` sections will have size + :code:`int(input.size(dim) / n) + 1`, and the rest will have size + :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list of ints, :attr:`input` is split along + dimension :attr:`dim` at each of the indices in the list. For instance, + :code:`[2, 3]` and :code:`dim=0` would result in the following tensors: + + - :code:`input[:2]` + - :code:`input[2:3]` + - :code:`input[3:]` + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + +Example:: + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) +""") + add_docstr(torch.chunk, r""" chunk(input, chunks, dim=0) -> List of Tensors diff --git a/torch/overrides.py b/torch/overrides.py index 43efda1da862..dc434e9c1f58 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -742,6 +742,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.tan: lambda input, out=None: -1, torch.tanh: lambda input, out=None: -1, torch.tensordot: lambda a, b, dims=2: -1, + torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, torch.threshold: lambda input, threshold, value, inplace=False: -1, torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, torch.trace: lambda input: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a6887395c19a..2a99ae643931 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1321,6 +1321,10 @@ def method_tests(): ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)), ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )), ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]), + ('tensor_split', (S, S, S), (3,), 'sections', (False,)), + ('tensor_split', (S, S, S), (3, 1), 'sections_dim', (False,), [1]), + ('tensor_split', (S, S, S), ([2, 4],), 'indices', (False,)), + ('tensor_split', (S, S, S), ([2, 4], 1), 'indices_dim', (False,), [1]), ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]), ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]), ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]), From c19b9cd18dd1da5f499ef0672e6871928618204d Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 7 Oct 2020 23:54:48 -0700 Subject: [PATCH 56/69] Add torch::cuda::ncll::all2all (#45900) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45900 Use `torch:cuda::nccl:all2all` from `ProcesGroupNCCL.cpp` Fixes https://github.com/pytorch/pytorch/issues/42517 Here is a NCCL dependency graph: ``` libnccl.a --> libtorch_cuda.so ---> libtorch_python.so | ^ | | --------> libc10d.a ----------------- ``` When static library is linked into a dynamic library or an executable, linker is removes all unused/duplicate symbols from that library, unless `-whole-archive` option is used. 
Before https://github.com/pytorch/pytorch/pull/42514 all nccl call made from `ProcessGroupNCCL.cpp` were also made from `torch/csrc/cuda/nccl.cpp`, which is compiled as part of `libtorch_cuda.so` But adding `ncclSend`|`ncclRecv` to ProcesGroupNCCL.cpp forced linker to embed those into `libtorch_python.so`, which also resulted in linking other dependent symbols into the library. This PR adds `nccl[Send|Recv]` call to `torch_cuda.so` by implementing `all2all` in `torch_cuda` and thus avoids double linking the static library. More involved, but prone solution, would be to use wrappers exported in `torch::cuda::nccl` namespace, instead of making direct NCCL API calls. Test Plan: Imported from OSS Reviewed By: mingzhe09088 Differential Revision: D24138011 Pulled By: malfet fbshipit-source-id: 33305197fc7d8707b7fd3a66b543f7733b9241a1 --- torch/csrc/cuda/comm.cpp | 2 +- torch/csrc/cuda/nccl.cpp | 36 ++++++++++++++++++++++++++ torch/csrc/cuda/nccl.h | 7 ++++++ torch/lib/c10d/NCCLUtils.hpp | 3 --- torch/lib/c10d/ProcessGroupNCCL.cpp | 39 ++++++----------------------- 5 files changed, 51 insertions(+), 36 deletions(-) diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index ca341305ec1d..1f85b0e1eba5 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -130,7 +130,7 @@ std::vector broadcast(const Tensor& tensor, IntArrayRef devices) { // When splitting, the view operations will make all Variables broadcast // together to share a single version counter, because they are all views of the // large Variable. However, that large Variable is immediately discarded and all -// these Varaibles do not share storage at all. +// these Variables do not share storage at all. // // For example, when two buffers are broadcast together in `DataParallel` and // one of them is modified in-place during `forward` but the other is needed in diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 780b129ab922..8b05caea5aba 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -624,6 +624,42 @@ void all_gather( AT_ERROR("PyTorch built without NCCL support"); #endif } + +void all2all(at::Tensor& input, + at::Tensor& output, + int size, + ncclComm_t _comm, + at::cuda::CUDAStream& stream) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 + using namespace torch::cuda::nccl::detail; + + int numranks; + auto type = to_nccl_data_type(input); + size_t count = input.numel() / size; + size_t rankdiff = input.nbytes() / size; + const auto* sendbuff = reinterpret_cast(input.data_ptr()); + auto* recvbuff = reinterpret_cast(output.data_ptr()); + auto comm = to_nccl_comm(_comm); + NCCL_CHECK(ncclCommCount(comm, &numranks)); + NCCL_CHECK(ncclGroupStart()); + for (int r = 0; r < numranks; r++) { + // NCCL uses 0 byte message for synchronization + // Avoid send/recv when message size is zero + if (count != 0) { + NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream)); + NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); + } + } + NCCL_CHECK(ncclGroupEnd()); +#else + AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 3550cf70aa58..ecf854ec2009 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -136,6 +136,13 @@ TORCH_CUDA_API void 
all_gather( const stream_list& streams = {}, const comm_list& user_comms = {}); +TORCH_CUDA_API void all2all( + at::Tensor& input, + at::Tensor& output, + int size, + ncclComm_t comm, + at::cuda::CUDAStream& stream); + } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/lib/c10d/NCCLUtils.hpp b/torch/lib/c10d/NCCLUtils.hpp index 433a71ef92d7..804667e28081 100644 --- a/torch/lib/c10d/NCCLUtils.hpp +++ b/torch/lib/c10d/NCCLUtils.hpp @@ -17,8 +17,6 @@ #define ENABLE_NCCL_ERROR_CHECKING #endif -// Fix build issues with NCCL P2P - until then disable NCCL send/recv. -#if defined(ENABLE_NCCL_A2A) && (ENABLE_NCCL_A2A == 1) // P2P is enabled only for NCCL versions 2.7+ since ncclSend() // and ncclRecv() are not supported in earlier versions. #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -27,7 +25,6 @@ #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) #define ENABLE_NCCL_P2P_SUPPORT #endif -#endif // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd) \ diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 6e45b8594f9b..4b687ab51c1a 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace c10d { @@ -165,31 +166,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { } #ifdef ENABLE_NCCL_P2P_SUPPORT -ncclResult_t ncclAlltoall( - void* sendbuff, - void* recvbuff, - size_t count, - size_t size, - ncclDataType_t type, - ncclComm_t comm, - cudaStream_t stream) { - int numranks; - size_t rankdiff = count * size; - C10D_NCCL_CHECK(ncclCommCount(comm, &numranks)); - C10D_NCCL_CHECK(ncclGroupStart()); - for (int r = 0; r < numranks; r++) { - // NCCL uses 0 byte message for synchronization - // Avoid send/recv when message size is zero - if (count != 0) { - C10D_NCCL_CHECK(ncclSend( - ((char*)sendbuff) + r * rankdiff, count, type, r, comm, stream)); - C10D_NCCL_CHECK(ncclRecv( - ((char*)recvbuff) + r * rankdiff, count, type, r, comm, stream)); - } - } - C10D_NCCL_CHECK(ncclGroupEnd()); - return ncclSuccess; -} ncclResult_t ncclAlltoallv( void* sendbuff, @@ -1386,14 +1362,13 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - return ncclAlltoall( - input.data_ptr(), - output.data_ptr(), - input.numel() / size_, - input.element_size(), - getNcclDataType(input.scalar_type()), + torch::cuda::nccl::all2all( + input, + output, + this->getSize(), comm, - stream.stream()); + stream); + return ncclSuccess; }); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); From 99d3f37bd4764b68f16f2de48ff1efd099d3e457 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 8 Oct 2020 00:00:20 -0700 Subject: [PATCH 57/69] Run gradgradcheck on torch.fft transforms (#46004) Summary: Ref https://github.com/pytorch/pytorch/issues/42175 As already noted in the `torch.fft` `gradcheck` tests, `gradcheck` isn't fully working for complex types yet and the function inputs need to be real. A similar workaround for `gradgradcheck` works, viewing the complex outputs as real before returning them makes `gradgradcheck` pass. 
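To make the workaround concrete, here is a condensed, standalone sketch of the pattern the test uses; the specific transform (`torch.fft.fft`) and the input shape are illustrative choices, not the exact test code:

```python
import torch

def test_fn(x):
    # x is the real view (shape (..., 2)) of a complex input
    out = torch.fft.fft(torch.view_as_complex(x))
    # view the complex output as real so the double-backward graph stays real-valued
    return torch.view_as_real(out) if out.is_complex() else out

# double precision is required for the finite-difference checks
inputs = (torch.view_as_real(torch.randn(4, dtype=torch.cdouble)).detach().requires_grad_(),)
assert torch.autograd.gradcheck(test_fn, inputs)
assert torch.autograd.gradgradcheck(test_fn, inputs)
```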
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46004 Reviewed By: ngimel Differential Revision: D24187000 Pulled By: mruberry fbshipit-source-id: 33c2986b07bac282dff1bd4f2109beb70e47bf79 --- test/test_spectral_ops.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 82ed2225bda8..8bd7249f3425 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -6,7 +6,7 @@ import itertools from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA) + (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, _assertGradAndGradgradChecks) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, precisionOverride, skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA) @@ -341,14 +341,16 @@ def test_fft_backward(self, device, dtype): # Use real input instead and put view_as_complex into the graph if dtype.is_complex: def test_fn(x): - return torch_fn(torch.view_as_complex(x), *args) - input = torch.view_as_real(input).detach().requires_grad_() + out = torch_fn(torch.view_as_complex(x), *args) + return torch.view_as_real(out) if out.is_complex() else out + inputs = (torch.view_as_real(input).detach().requires_grad_(),) else: def test_fn(x): - return torch_fn(x, *args) - input = input.detach().requires_grad_() + out = torch_fn(x, *args) + return torch.view_as_real(out) if out.is_complex() else out + inputs = (input.detach().requires_grad_(),) - self.assertTrue(torch.autograd.gradcheck(test_fn, (input,))) + _assertGradAndGradgradChecks(self, test_fn, inputs) # nd-fft tests @@ -473,14 +475,16 @@ def test_fftn_backward(self, device, dtype): # Use real input instead and put view_as_complex into the graph if dtype.is_complex: def test_fn(x): - return torch_fn(torch.view_as_complex(x), s, dim, norm) + out = torch_fn(torch.view_as_complex(x), s, dim, norm) + return torch.view_as_real(out) if out.is_complex() else out inputs = (torch.view_as_real(input).detach().requires_grad_(),) else: def test_fn(x): - return torch_fn(x, s, dim, norm) + out = torch_fn(x, s, dim, norm) + return torch.view_as_real(out) if out.is_complex() else out inputs = (input.detach().requires_grad_(),) - self.assertTrue(torch.autograd.gradcheck(test_fn, inputs)) + _assertGradAndGradgradChecks(self, test_fn, inputs) @skipCUDAIfRocm @skipCPUIfNoMkl From a92b49f7c8f6642e61ae6155c6b18962d4bd9086 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Thu, 8 Oct 2020 00:16:29 -0700 Subject: [PATCH 58/69] [Onnxifi] Don't throw exception when we cannot write out debug files (#45979) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45979 For some reason, sometime we cannot write out the debug files. This shouldn't block the whole service. Hence, we opt in to error out instead of throw error. Test Plan: Run net_runner test at `/` and observe error being printed out but the test passes. 
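For clarity, the intended behavior can be sketched in Python (the helper and paths below are illustrative only, not part of the caffe2 API): debug dumps become best-effort and log an error on failure instead of taking the whole run down.

```python
import logging

def write_debug_dump(text: str, fname: str, throw_if_error: bool = True) -> None:
    # Mirrors the throwIfError flag added to WriteProtoToTextFile in the diff below.
    try:
        with open(fname, "w") as f:
            f.write(text)
    except OSError as exc:
        if throw_if_error:
            raise
        logging.error("Cannot write debug file %s: %s", fname, exc)

# Debug-only call sites opt out of throwing:
write_debug_dump("net proto ...", "/read-only/debug_pre_ssa_net.pb_txt", throw_if_error=False)
```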
Reviewed By: ipiszy Differential Revision: D24165081 fbshipit-source-id: a4e1d0479d54d741e615e3a00b3003f512394fd4 --- caffe2/opt/backend_cutting.cc | 8 ++++++-- caffe2/opt/backend_transformer_base.cc | 2 +- caffe2/opt/glow_net_transform.cc | 3 --- caffe2/opt/onnxifi_transformer.cc | 18 +++++++++++------- caffe2/utils/proto_utils.cc | 11 +++++++++-- caffe2/utils/proto_utils.h | 14 ++++++++------ 6 files changed, 35 insertions(+), 21 deletions(-) diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc index e1f7808d48b9..45f46ab48330 100644 --- a/caffe2/opt/backend_cutting.cc +++ b/caffe2/opt/backend_cutting.cc @@ -352,9 +352,13 @@ void DumpGraph(NNGraph* g, const std::string& fname) { }; std::ofstream out(fname.c_str()); - out << nom::converters::convertToDotString(g, nnprinter); - out.close(); + if (out) { + out << nom::converters::convertToDotString(g, nnprinter); + } else { + LOG(ERROR) << "Cannot create nomnigraph dump file: " << fname; + } } + caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function supports, diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc index 7bb27fca92ab..9090e0b5277b 100644 --- a/caffe2/opt/backend_transformer_base.cc +++ b/caffe2/opt/backend_transformer_base.cc @@ -177,6 +177,6 @@ void BackendTransformerBase::dumpNet( const std::string& fname) const { NetDef shape_net(pred_net); addShapeToNet(shape_net, shape_hints); - WriteProtoToTextFile(shape_net, fname); + WriteProtoToTextFile(shape_net, fname, false); } } // namespace caffe2 diff --git a/caffe2/opt/glow_net_transform.cc b/caffe2/opt/glow_net_transform.cc index f021d263106d..12bd060c27d6 100644 --- a/caffe2/opt/glow_net_transform.cc +++ b/caffe2/opt/glow_net_transform.cc @@ -222,9 +222,6 @@ void onnxifi( OnnxifiTransformer ts(opts); ts.transform(ws, net, weight_names, more_shape_hints, more_blacklist); - if (FLAGS_onnxifi_debug_mode) { - WriteProtoToTextFile(*net, "debug_transformed_net.pb_txt"); - } // Cleanup the input from the workspace for (const auto& i : input_names) { diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 9166153cf693..e849f0edb272 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -842,7 +842,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( int onnxifi_op_id = onnxifi_op_id_; if (opts_.debug) { WriteProtoToTextFile( - net, "debug_original_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + net, + "debug_original_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); } if (opts_.min_ops > net.op_size()) { return net; @@ -970,10 +972,12 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( if (opts_.debug) { WriteProtoToTextFile( onnxifi_net, - "debug_onnxifi_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + "debug_onnxifi_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); WriteProtoToTextFile( net_opt, - "debug_optimized_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + "debug_optimized_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); } return net_opt; } @@ -1087,8 +1091,8 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( // Debugging stuff if (opts_.debug) { - WriteProtoToTextFile(onnx_model, "debug_onnxifi_net.onnx_txt"); - WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt"); + WriteProtoToTextFile(onnx_model, "debug_onnxifi_net.onnx_txt", false); + WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt", false); } return net_opt; } @@ -1467,7 +1471,7 @@ void 
OnnxifiTransformer::transform( CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); if (opts_.debug) { - WriteProtoToTextFile(*pred_net, "debug_pre_ssa_net.pb_txt"); + WriteProtoToTextFile(*pred_net, "debug_pre_ssa_net.pb_txt", false); } // Get model id and reset Onnxifi op id to 0 @@ -1548,7 +1552,7 @@ void OnnxifiTransformer::transform( addShapeToNet(*pred_net, shape_hints); if (opts_.debug) { - WriteProtoToTextFile(*pred_net, "debug_full_opt_net.pb_txt"); + WriteProtoToTextFile(*pred_net, "debug_full_opt_net.pb_txt", false); } } diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc index e8e42e1bbf3e..5fc34d631c4f 100644 --- a/caffe2/utils/proto_utils.cc +++ b/caffe2/utils/proto_utils.cc @@ -217,10 +217,17 @@ C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) { C10_EXPORT void WriteProtoToTextFile( const Message& proto, - const char* filename) { + const char* filename, + bool throwIfError) { int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); FileOutputStream* output = new FileOutputStream(fd); - CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output)); + if(!google::protobuf::TextFormat::Print(proto, output)) { + if (throwIfError) { + CAFFE_THROW("Cannot write proto to text file: ", filename); + } else { + LOG(ERROR) << "Cannot write proto to text file: " << filename; + } + } delete output; close(fd); } diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index b5e1b52f742d..35023265c982 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -80,13 +80,15 @@ inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { inline void WriteProtoToTextFile( const MessageLite& /*proto*/, - const char* /*filename*/) { + const char* /*filename*/, + bool throwIfError = true) { LOG(FATAL) << "If you are running lite version, you should not be " << "calling any text-format protobuffers."; } inline void WriteProtoToTextFile(const MessageLite& proto, - const string& filename) { - return WriteProtoToTextFile(proto, filename.c_str()); + const string& filename, + bool throwIfError = true) { + return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); } inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { @@ -115,9 +117,9 @@ inline bool ReadProtoFromTextFile(const string filename, Message* proto) { return ReadProtoFromTextFile(filename.c_str(), proto); } -CAFFE2_API void WriteProtoToTextFile(const Message& proto, const char* filename); -inline void WriteProtoToTextFile(const Message& proto, const string& filename) { - return WriteProtoToTextFile(proto, filename.c_str()); +CAFFE2_API void WriteProtoToTextFile(const Message& proto, const char* filename, bool throwIfError = true); +inline void WriteProtoToTextFile(const Message& proto, const string& filename, bool throwIfError = true) { + return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); } // Read Proto from a file, letting the code figure out if it is text or binary. 
From c9caa828f5320fe635b6e785cc007654a4177a1b Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Thu, 8 Oct 2020 00:43:55 -0700 Subject: [PATCH 59/69] Throw special exception when backend compilation is met with fatal error (#45952) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45952 Pull Request resolved: https://github.com/pytorch/glow/pull/4967 When glow compilation meets with nonrecoverable fatal error (hardware is busted), we would like to throw a special exception other than the normal caffe2::EnforceNotMet so that we can signal the upper layer application to handle it differently. Test Plan: Manually code some error and add LOG(FATAL) in the special exception path and wait for application to fatal. Reviewed By: ipiszy Differential Revision: D24156792 fbshipit-source-id: 4ae21bb0d36c89eac331fc52dd4682826b3ea180 --- c10/util/Exception.h | 8 +++++++- caffe2/opt/onnxifi_op.h | 35 +++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 4b55c562130c..3a80cd1d3fb4 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -181,6 +181,12 @@ class C10_API EnforceFiniteError : public Error { using Error::Error; }; +// Used in Onnxifi backend lowering. These turn into +// ExitException when they cross to Python. +class C10_API OnnxfiBackendSystemError : public Error { + using Error::Error; +}; + // A utility function to return an exception std::string by prepending its // exception type before its what() content C10_API std::string GetExceptionString(const std::exception& e); @@ -340,7 +346,7 @@ inline std::string if_empty_then(std::string x, std::string y) { #endif #define TORCH_CHECK(cond, ...) TORCH_CHECK_WITH(Error, cond, __VA_ARGS__) -// An utility macro that does what `TORCH_CHECK` does if compiled in the host code, +// An utility macro that does what `TORCH_CHECK` does if compiled in the host code, // otherwise does nothing. Supposed to be used in the code shared between host and // device code as an alternative for `TORCH_CHECK`. #if defined(__CUDACC__) || defined(__HIPCC__) diff --git a/caffe2/opt/onnxifi_op.h b/caffe2/opt/onnxifi_op.h index f19403a14e58..865fdf301ca1 100644 --- a/caffe2/opt/onnxifi_op.h +++ b/caffe2/opt/onnxifi_op.h @@ -4,6 +4,7 @@ #include "onnx/onnx_pb.h" +#include "c10/util/Exception.h" #include "c10/util/SmallVector.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" @@ -65,7 +66,7 @@ class OnnxifiOp final : public Operator { CAFFE_ENFORCE(!onnx_model_str.empty(), "onnx_model cannot be empty"); if (use_glow_aot_) { auto netdef_str = - this->template GetSingleArgument("netdef_str", ""); + this->template GetSingleArgument("netdef_str", ""); CAFFE_ENFORCE(ParseProtoFromLargeString(netdef_str, &netdef_)); } else if (!use_onnx_) { CAFFE_ENFORCE(ParseProtoFromLargeString(onnx_model_str, &netdef_)); @@ -187,7 +188,7 @@ class OnnxifiOp final : public Operator { this->template GetRepeatedArgument("initializers"); // Build the Onnxifi engine auto backend_index = - this->template GetSingleArgument("backend_id", use_onnx_ ? 1 : 0); + this->template GetSingleArgument("backend_id", use_onnx_ ? 1 : 0); // If using Glow AOT, override the backend_id to 1, since it uses a custom // ONNX format, and that's the id we use for the ONNX backend. 
if (use_glow_aot_) { @@ -266,18 +267,24 @@ class OnnxifiOp final : public Operator { static const uint64_t auxPropertiesListAOT[] = { ONNXIFI_OPTIMIZATION_AOT, ONNXIFI_GRAPH_PROPERTY_NONE}; - CAFFE_ENFORCE_EQ( - lib_->onnxInitGraph( - backend, - use_glow_aot_ ? auxPropertiesListAOT : nullptr, - onnx_model_str.size(), - (const void*)(onnx_model_str.c_str()), - weight_descs.size(), - weight_descs.data(), - &graph, - static_cast(max_seq_size_), - defered_blob_reader), - ONNXIFI_STATUS_SUCCESS); + auto ret = lib_->onnxInitGraph( + backend, + use_glow_aot_ ? auxPropertiesListAOT : nullptr, + onnx_model_str.size(), + (const void*)(onnx_model_str.c_str()), + weight_descs.size(), + weight_descs.data(), + &graph, + static_cast(max_seq_size_), + defered_blob_reader); + if (ret != ONNXIFI_STATUS_SUCCESS) { + if (ret == ONNXIFI_STATUS_FATAL_ERROR) { + C10_THROW_ERROR( + OnnxfiBackendSystemError, "Fatal error during onnxInitGraph"); + } else { + CAFFE_THROW("onnxInitGraph failed"); + } + } return std::make_shared( backend_id, backend, graph, lib_, std::move(weight_shape_info)); From b65ffa365cee426e631daf259524ec990b2f4f3d Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Thu, 8 Oct 2020 00:47:00 -0700 Subject: [PATCH 60/69] [TensorExpr] Nuke `Function` class and directly use `Tensor` instead. (#45936) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45936 `Tensor` has been a view into a `Function` that was supposed to be used for a more general case when we have multiple computations over the same domain (aka multiple output functions). We have never got to a point where we need this and now have other ideas in mind on how to support this case if need be. For now, let's just nuke `Function` to reduce the overall system complexity. The change should not affect any existing behavior. Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24153214 Pulled By: ZolotukhinM fbshipit-source-id: 26d5f11db5d661ff5e1135f4a49eff1c6d4c1bd5 --- test/cpp/tensorexpr/tutorial.cpp | 52 ++------ torch/csrc/jit/tensorexpr/codegen.h | 12 +- torch/csrc/jit/tensorexpr/expr.h | 3 + torch/csrc/jit/tensorexpr/ir_printer.cpp | 29 ----- torch/csrc/jit/tensorexpr/ir_printer.h | 5 - torch/csrc/jit/tensorexpr/loopnest.cpp | 43 ++++--- torch/csrc/jit/tensorexpr/tensor.cpp | 25 ++-- torch/csrc/jit/tensorexpr/tensor.h | 157 ++++++++--------------- 8 files changed, 101 insertions(+), 225 deletions(-) diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index f0bcfc4c2485..31e05549186e 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -125,54 +125,30 @@ int main(int argc, char* argv[]) { // independent computations over the same domain) for its elements, as a // function of indices // - // We use Function objects to represent this. Let's build one. - // - // First, we need to specify the domain, or dimensions in which the - // computation would be performed. Let's create a 64x32 domain: + // TODO: Update this section once Tensor/Function cleanup is done std::vector dims = { new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate // and represents an integer constant - // Next we need to create Function arguments. The arguments of a Function - // are Vars, and they play role of placeholders. The computation that the - // function would describe would use these arguments. + // Next we need to create arguments. The arguments are Vars, and they play + // role of placeholders. 
The computation that the tensor would describe + // would use these arguments. const Var* i = new Var("i", kInt); const Var* j = new Var("j", kInt); std::vector args = {i, j}; - // Now we can define the function computations using these arguments. Let's - // create two computations, the first would add the arguments of the - // function, the second would multiply them. - Expr* func_body1 = new Mul(i, j); - Expr* func_body2 = new Add(i, j); - - // Finally, we pass all these pieces together to Function constructor: - Function* func = - new Function({"X", "Y"}, dims, args, {func_body1, func_body2}); - // Under the hood function constructor would create separate `Buf` - // expressions for each computation (which can be accessed via - // `func->func_var(idx)`) with the names specified by the first parameter of - // the constructor call. In our example two `Buf` variables will be created - // with names 'X' and 'Y', each of them would signify a domain of 64x32. - - // We can now print out our function: - std::cout << "Tensor function: " << *func << std::endl; - // Prints: - // Tensor function: Function F(i[64], j[32]) { - // X = i * j - // Y = i + j - // } + // Now we can define the body of the tensor computation using these + // arguments. + Expr* body = new Mul(i, j); - // A Tensor refers to an individual computation defined by a Function. For - // instance, we could create a following tensor given the function above: - int output_idx = 0; // Used to index the computation - Tensor* X = new Tensor(func, output_idx); + // Finally, we pass all these pieces together to Tensor constructor: + Tensor* X = new Tensor("X", dims, args, body); std::cout << "Tensor computation: " << *X << std::endl; // Prints: Tensor computation: Tensor X(i[64], j[32]) = i * j // Similarly to how we provide a more convenient way of using handles for // constructing Exprs, Tensors also have a more convenient API for - // construction. It is based on Compute functions, which take a name: + // construction. It is based on Compute API, which takes a name, // dimensions, and a lambda specifying the computation body: Tensor* Z = Compute( "Z", @@ -204,14 +180,6 @@ int main(int argc, char* argv[]) { // Tensor and we use 'load' for accessing elements of an external tensor // through its Placeholder. This is an implementation detail and could be // changed in future. - // - // Why do we have Functions and Tensors and what is the relationship between - // them? Functions are used to represent several computations performed over - // the same domain. Tensors refer to individual computations of a Function. - // - // Also note that currently a lot of code only supports single-output - // Functions, in which case they become almost identical to Tensors. This - // probably will be changed in future. 
// TODO: Show how reductions are represented and constructed } diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 4bf9d7680ad0..a32b362fe3cd 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -73,17 +73,7 @@ class CodeGen::BufferArg { BufferArg(const Placeholder& buffer) : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {} BufferArg(Tensor* tensor) - : var_(tensor->function() - ->func_var(tensor->output_index()) - ->base_handle()), - dtype_(tensor->function()->body(tensor->output_index())->dtype()) {} - BufferArg(const Function& func) - : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) { - // TODO: Support multiple-output functions - if (func.func_vars().size() != 1) { - throw unimplemented_lowering(); - } - } + : var_(tensor->buf()->base_handle()), dtype_(tensor->body()->dtype()) {} BufferArg(const VarHandle& var) : var_(var.node()), dtype_(var.dtype()), isVar_(true) {} diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 434aa52db815..7c64403a10bd 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -194,6 +194,9 @@ class TORCH_API Buf : public ExprNode { return dims_.size(); } const Expr* dim(size_t index) const { + if (index >= ndim()) { + throw out_of_range_index(); + } return dims_[index]; } std::vector dims() const { diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 8f47d855500f..5792729ac7d9 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -552,11 +552,6 @@ std::ostream& operator<<(std::ostream& stream, const Tensor& t) { return stream; } -std::ostream& operator<<(std::ostream& stream, const Function& f) { - stream << std::to_string(&f); - return stream; -} - void print(const Expr* expr) { if (expr) { IRPrinter p(std::cout); @@ -579,10 +574,6 @@ void print(const Tensor* t) { std::cout << std::to_string(t); } -void print(const Function* f) { - std::cout << std::to_string(f); -} - } // namespace tensorexpr } // namespace jit } // namespace torch @@ -615,24 +606,4 @@ std::string to_string(const Tensor* t) { oss << ") = " << *t->body() << "\n"; return oss.str(); } - -std::string to_string(const Function* f) { - if (!f) { - return "(null function)\n"; - } - std::ostringstream oss; - oss << "Function F("; - for (size_t i = 0; i < f->ndim(); i++) { - if (i != 0) { - oss << ", "; - } - oss << *f->arg(i) << "[" << *f->dim(i) << "]"; - } - oss << ") {\n"; - for (size_t i = 0; i < f->bodies().size(); i++) { - oss << " " << *f->func_var(i) << " = " << *f->body(i) << "\n"; - } - oss << "}\n"; - return oss.str(); -} } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index 64ba35280371..d9079d7fb717 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -11,7 +11,6 @@ namespace jit { namespace tensorexpr { class Tensor; -class Function; class TORCH_API IRPrinter : public IRVisitor { public: @@ -95,12 +94,10 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&); TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&); -TORCH_API std::ostream& operator<<(std::ostream& stream, const Function&); TORCH_API void print(const Expr* 
expr); TORCH_API void print(const Stmt* stmt); TORCH_API void print(const Tensor* t); -TORCH_API void print(const Function* f); } // namespace tensorexpr } // namespace jit @@ -109,12 +106,10 @@ TORCH_API void print(const Function* f); namespace std { using torch::jit::tensorexpr::Expr; -using torch::jit::tensorexpr::Function; using torch::jit::tensorexpr::Stmt; using torch::jit::tensorexpr::Tensor; TORCH_API std::string to_string(const Expr* expr); TORCH_API std::string to_string(const Stmt* stmt); TORCH_API std::string to_string(const Tensor* t); -TORCH_API std::string to_string(const Function* f); } // namespace std diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 456a264006e1..301f11fbb5e2 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -449,11 +449,9 @@ LoopNest::LoopNest(const std::vector& output_tensors) } Stmt* LoopNest::lowerToStmt(Tensor* t) { - Function* f = t->function(); - // TODO: Support multiple-output functions - Stmt* body = f->ElementStmt(0); + Stmt* body = t->ElementStmt(); - if (f->ndim() == 0) { + if (t->ndim() == 0 && t->reduce_ndim() == 0) { return body; } @@ -461,18 +459,30 @@ Stmt* LoopNest::lowerToStmt(Tensor* t) { if (initializer) { buf_initializers_[t->buf()] = initializer; } + std::vector indices(t->args().begin(), t->args().end()); - for (size_t i = 0; i < f->ndim(); i++) { - // Going in reverse order: from innermost loop to the outermost - size_t dim_index = f->ndim() - i - 1; - body = new For(f->arg(dim_index), new IntImm(0), f->dim(dim_index), body); - indices.pop_back(); - if (initializer && indices.size() == t->ndim()) { + if (t->reduce_ndim() > 0) { + for (size_t i = 0; i < t->reduce_ndim(); i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = t->reduce_ndim() - i - 1; + body = new For( + t->reduce_arg(dim_index), + new IntImm(0), + t->reduce_dim(dim_index), + body); + } + if (initializer) { Store* init = new Store(t->buf(), indices, initializer, new IntImm(1)); body = new Block({init, body}); } } + + for (size_t i = 0; i < t->ndim(); i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = t->ndim() - i - 1; + body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body); + } return body; } @@ -493,26 +503,21 @@ class FunctionInliner : public IRMutator { // For the target function, insert the caller/callee pair into the replacement // mapping. 
const Expr* mutate(const FunctionCall* v) override { - Function* func = v->tensor()->function(); - const Buf* buf = v->tensor()->buf(); + const Tensor* t = v->tensor(); + const Buf* buf = t->buf(); if (buf != buf_) { return IRMutator::mutate(v); } - // TODO: Support multiple-output functions - if (func->func_vars().size() != 1) { - throw unimplemented_lowering(); - } - if (v->nparams() != buf->ndim()) { throw malformed_input( "Placeholder indexed access is inconsistent with its rank", v); } std::vector index_vars; - TORCH_INTERNAL_ASSERT(buf->ndim() == func->args().size()); + TORCH_INTERNAL_ASSERT(buf->ndim() == t->args().size()); for (size_t i = 0; i < buf->ndim(); i++) { - const Var* func_callee_arg = dynamic_cast(func->arg(i)); + const Var* func_callee_arg = dynamic_cast(t->arg(i)); const Expr* func_caller_param = v->param(i); auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 4fad4cac9a6d..4afc1ffeefb5 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -16,8 +16,7 @@ Tensor* Compute( std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); - Function* func = new Function(func_name, dims, args, body); - return new Tensor(func, 0); + return new Tensor(func_name, dims, args, body); } Tensor* Compute( @@ -32,8 +31,7 @@ Tensor* Compute( std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarHandle(args[0])).node(); - Function* func = new Function(func_name, dims, args, body); - return new Tensor(func, 0); + return new Tensor(func_name, dims, args, body); } Tensor* Compute( @@ -48,8 +46,7 @@ Tensor* Compute( std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); - Function* func = new Function(func_name, dims, args, body); - return new Tensor(func, 0); + return new Tensor(func_name, dims, args, body); } Tensor* Compute( @@ -67,8 +64,7 @@ Tensor* Compute( const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) .node(); - Function* func = new Function(func_name, dims, args, body); - return new Tensor(func, 0); + return new Tensor(func_name, dims, args, body); } Tensor* Compute( @@ -87,20 +83,17 @@ Tensor* Compute( unpack_dim_args(dim_args, &dims, &args_nodes); auto args = VarVectorToVarHandleVector(args_nodes); const Expr* body = body_func(args[0], args[1], args[2], args[3]).node(); - Function* func = new Function(func_name, dims, args_nodes, body); - return new Tensor(func, 0); + return new Tensor(func_name, dims, args_nodes, body); } -Stmt* Function::ElementStmt(size_t index) { - const Buf* buf = func_var(index); +Stmt* Tensor::ElementStmt() { std::vector indices; - for (size_t i = 0; i < buf->ndim(); i++) { - indices.push_back(this->args_[i]); + for (size_t i = 0; i < buf_->ndim(); i++) { + indices.push_back(args_[i]); } const Expr* mask = new IntImm(1); - - Stmt* update_stmt = new Store(buf, indices, body(index), mask); + Stmt* update_stmt = new Store(buf_, indices, body_, mask); return update_stmt; } diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index 9d0cadc52686..d37f14c3a606 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -12,129 +12,80 @@ namespace torch { namespace jit { namespace tensorexpr { -class Function 
: public KernelScopedObject { +class Tensor : KernelScopedObject { public: - Function( - const std::string& func_name, + Tensor( + const std::string& name, const std::vector& dims, const std::vector& args, const Expr* body) // TODO: Function should not create buffers, they should be created // manually before constructing a function. - : func_vars_({new Buf(func_name, dims, body->dtype())}), - dims_(dims), - args_(args), - bodies_({body}) {} - Function( - const std::vector& func_names, - const std::vector& dims, - const std::vector& args, - const std::vector& bodies) - : func_vars_(func_names.size()), - dims_(dims), - args_(args), - bodies_(bodies) { - for (size_t i = 0; i < func_names.size(); i++) { - func_vars_[i] = new Buf(func_names[i], dims, bodies[i]->dtype()); - } - } - Function( - const std::string& func_name, - Buf* func_var, - const std::vector& dims, + : buf_(new Buf(name, dims, body->dtype())), args_(args), body_(body) {} + + Tensor(Buf* buf, const std::vector& args, const Expr* body) + : buf_(buf), args_(args), body_(body) {} + + Tensor( + Buf* buf, const std::vector& args, + const std::vector& reduce_dims, + const std::vector& reduce_args, const Expr* body) - : func_vars_({func_var}), dims_(dims), args_(args), bodies_({body}) {} + : buf_(buf), + args_(args), + body_(body), + reduce_dims_(reduce_dims), + reduce_args_(reduce_args) {} + // Wrappers over accessors to fields of the underlying function + const Expr* body() const { + return body_; + } + const Buf* buf() const { + return buf_; + } size_t ndim() const { - return dims_.size(); + return buf()->ndim(); } - const Expr* dim(size_t index) const { - if (index < 0 || index >= dims_.size()) { + if (index >= ndim()) { throw out_of_range_index(); } - - return dims_[index]; + return buf()->dim(index); } - const std::vector& dims() const { - return dims_; + std::vector dims() const { + return buf()->dims(); } - const Var* arg(size_t index) const { - if (index < 0 || index >= args_.size()) { + if (index >= ndim()) { throw out_of_range_index(); } - return args_[index]; } const std::vector& args() const { return args_; } - - std::vector bodies() const { - return bodies_; + size_t reduce_ndim() const { + return reduce_dims_.size(); } - const Expr* body(size_t index) const { - if (index >= bodies_.size()) { - throw out_of_range_index(); - } - - return bodies_[index]; + std::vector reduce_dims() const { + return reduce_dims_; } - - std::vector func_vars() const { - return func_vars_; + std::vector reduce_args() const { + return reduce_args_; } - const Buf* func_var(size_t index) const { - if (index >= func_vars_.size()) { + const Expr* reduce_dim(size_t index) const { + if (index >= reduce_ndim()) { throw out_of_range_index(); } - return func_vars_[index]; - } - - Stmt* ElementStmt(size_t index); - - private: - std::vector func_vars_; - std::vector dims_; - std::vector args_; - std::vector bodies_; -}; - -class Tensor : KernelScopedObject { - public: - Tensor(Function* function, int output_index) - : function_(function), output_index_(output_index) {} - - Function* function() const { - return function_; - } - int output_index() const { - return output_index_; - } - - // Wrappers over accessors to fields of the underlying function - const Expr* body() const { - return function()->body(output_index()); - } - const Buf* buf() const { - return function()->func_var(output_index()); - } - int ndim() const { - return buf()->dims().size(); - } - const Expr* dim(int index) const { - return buf()->dim(index); + return reduce_dims_[index]; } - 
std::vector dims() const { - return buf()->dims(); - } - const Var* arg(int index) const { - return function()->arg(index); - } - const std::vector& args() const { - return function()->args(); + const Var* reduce_arg(size_t index) const { + if (index >= reduce_ndim()) { + throw out_of_range_index(); + } + return reduce_args_[index]; } void initializeTo(const Expr* initializer) { @@ -143,6 +94,7 @@ class Tensor : KernelScopedObject { const Expr* initializer() const { return initializer_; } + Stmt* ElementStmt(); template inline ExprHandle operator()(const Ts&... ts); @@ -152,8 +104,12 @@ class Tensor : KernelScopedObject { inline ExprHandle call(const Ts&... ts); private: - Function* function_; - int output_index_; + const Buf* buf_; + std::vector args_; + const Expr* body_; + std::vector reduce_dims_; + std::vector reduce_args_; + const Expr* initializer_{nullptr}; }; @@ -295,10 +251,8 @@ Tensor* Reduce( Buf* func_result = new Buf(func_name, dims, body.dtype()); const ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars); - dims.insert(dims.end(), reduce_dims.begin(), reduce_dims.end()); - Function* func = - new Function(func_name, func_result, dims, all_vars, reduce_op); - Tensor* t = new Tensor(func, 0); + Tensor* t = + new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); t->initializeTo(new Cast(body.dtype(), reducer.initializer())); return t; } @@ -352,10 +306,7 @@ class FunctionCall : public CallNode { } FunctionCall(Tensor* tensor, const std::vector& params) - : BaseClass( - tensor->function()->body(tensor->output_index())->dtype(), - kFunctionCall, - params), + : BaseClass(tensor->body()->dtype(), kFunctionCall, params), tensor_(tensor) {} private: From 598caddd933c5f983c2d2e4899e4656475b8a3fa Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Thu, 8 Oct 2020 00:47:00 -0700 Subject: [PATCH 61/69] [TensorExpr] Add shorthand versions for `splitWith{Mask,Tail}` functions. (#45946) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45946 Also, make these functions static - they are not using anything from `LoopNest` and can be applied to any `Stmt`. 
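For readers new to these transformations, here is a plain-Python sketch of what the two splits do (an illustration of the loop transformation only, not TensorExpr code):

```python
n, factor = 10, 4
original = [i * i for i in range(n)]

# splitWithTail: full tiles plus an explicit tail loop for the remainder.
with_tail = []
for i_outer in range(n // factor):
    for i_inner in range(factor):
        i = i_outer * factor + i_inner
        with_tail.append(i * i)
for i_tail in range(n % factor):
    i = (n // factor) * factor + i_tail
    with_tail.append(i * i)

# splitWithMask: round the trip count up and guard the body with a bounds check.
with_mask = []
for i_outer in range((n + factor - 1) // factor):
    for i_inner in range(factor):
        i = i_outer * factor + i_inner
        if i < n:  # the "mask"
            with_mask.append(i * i)

assert with_tail == original and with_mask == original
```

The shorthand overloads simply discard the `outer`/`inner`/`tail` out-parameters when the caller does not need handles to the newly created loops.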
Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24156002 Pulled By: ZolotukhinM fbshipit-source-id: 1c7d205f85a2a1684e07eb836af662f10d0a50fc --- torch/csrc/jit/tensorexpr/loopnest.cpp | 12 +++++++++--- torch/csrc/jit/tensorexpr/loopnest.h | 12 ++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 301f11fbb5e2..68a392662a5f 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -989,6 +989,11 @@ void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { // TODO: record history of transformations } +void LoopNest::splitWithTail(For* f, int factor) { + For *outer, *inner, *tail; + splitWithTail(f, factor, &outer, &inner, &tail); +} + void LoopNest::splitWithTail( For* f, int factor, @@ -1054,8 +1059,11 @@ void LoopNest::splitWithTail( } else { *tail = nullptr; } +} - // TODO: record history of transformations +void LoopNest::splitWithMask(For* f, int factor) { + For *outer, *inner; + splitWithMask(f, factor, &outer, &inner); } void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { @@ -1115,8 +1123,6 @@ void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { // TODO: cleanup API for adding/removing statements p->replace_stmt(f, *outer); - - // TODO: record history of transformations } For* findOuterFor(For* a, For* b) { diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 391bdbeb1c37..911cad93c5a1 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -39,8 +39,16 @@ class TORCH_API LoopNest { void computeInline(Stmt* s); void computeInline(const Buf* b); - void splitWithTail(For* f, int factor, For** outer, For** inner, For** tail); - void splitWithMask(For* f, int factor, For** outer, For** inner); + static void splitWithTail(For* f, int factor); + static void splitWithTail( + For* f, + int factor, + For** outer, + For** inner, + For** tail); + + static void splitWithMask(For* f, int factor); + static void splitWithMask(For* f, int factor, For** outer, For** inner); void reorderAxis(For* a, For* b); From 29da553dd933996a8b30f1ddeecaa22b5c6eb8f3 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Thu, 8 Oct 2020 00:47:00 -0700 Subject: [PATCH 62/69] [TensorExpr] Loopnest: unify intermediate_tensors_ and temp_bufs_. 
(#45947) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45947 Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24155999 Pulled By: ZolotukhinM fbshipit-source-id: d82acf6aba570f6a675eea683c306088e2a41f91 --- torch/csrc/jit/tensorexpr/loopnest.cpp | 46 +++++--------------------- torch/csrc/jit/tensorexpr/loopnest.h | 3 +- 2 files changed, 10 insertions(+), 39 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 68a392662a5f..39a45d427862 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -427,7 +427,7 @@ LoopNest::LoopNest(const std::vector& output_tensors) tensors_to_compute.begin(), tensors_to_compute.end()); for (Tensor* t : tensors_to_compute) { if (!output_tensors_.count(t)) { - intermediate_tensors_.insert(t); + intermediate_bufs_.insert(t->buf()); } } @@ -679,19 +679,7 @@ void LoopNest::computeInline(const Buf* b) { root_stmt_ = root_stmt_->accept_mutator(&inliner); // No longer computing this intermediate tensor, so don't alloc it. - for (auto* t : intermediate_tensors_) { - if (b == t->buf()) { - intermediate_tensors_.erase(t); - break; - } - } - - for (auto it = temp_bufs_.begin(); it != temp_bufs_.end(); ++it) { - if (b == *it) { - temp_bufs_.erase(it); - break; - } - } + intermediate_bufs_.erase(b); } // TODO: Unify with DepTracker @@ -788,9 +776,7 @@ Block* findLowestContainingBlock(const std::vector& uses) { } Stmt* LoopNest::insertAllocFree(Stmt* stmt) { - // Add allocs and frees for intermediate buffers at the global level. - // TODO: move allocs and frees to the imemediate areas to reuse buffers. - if (intermediate_tensors_.size() == 0ULL && temp_bufs_.size() == 0ULL) { + if (intermediate_bufs_.size() == 0ULL) { return stmt; } @@ -799,31 +785,17 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) { b = new Block({stmt}); } - // TODO: Fix the traversal, currently the order is non-deterministic - for (Tensor* tensor : intermediate_tensors_) { - if (output_tensors_.count(tensor) > 0) { - // No need to allocate memory if the tensors are given as input/output. - continue; - } - Stmt* alloc = new Allocate( - tensor->buf()->base_handle(), tensor->body()->dtype(), tensor->dims()); - Stmt* free = new Free(tensor->buf()->base_handle()); - b->prepend_stmt(alloc); - b->append_stmt(free); - } - - // Now insert allocations and frees for temporary buffers. Do that in the - // innermost possible scope. std::unordered_map> uses = findUses(stmt); - - for (const auto& buf : temp_bufs_) { + // Insert allocations and frees for temporary buffers in the innermost + // possible scope. + for (const Buf* buf : intermediate_bufs_) { Stmt* alloc = new Allocate(buf->base_handle(), buf->dtype(), buf->dims()); Stmt* free = new Free(buf->base_handle()); - Block* alloc_block = findLowestContainingBlock(uses.at(buf)); alloc_block->prepend_stmt(alloc); alloc_block->append_stmt(free); } + return b; } @@ -1655,7 +1627,7 @@ void LoopNest::computeAt(Stmt* s, For* f) { // Mark the new temp buffer as requiring an alloc (it will be inserted as a // part of prepareForCodegen). 
- temp_bufs_.emplace_back(temp_buf); + intermediate_bufs_.insert(temp_buf); } class SwapReduce : public IRMutator { @@ -1933,7 +1905,7 @@ void LoopNest::rfactor( } tmp_buf->set_dims(tmp_dims); - temp_bufs_.emplace_back(tmp_buf); + intermediate_bufs_.insert(tmp_buf); } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 911cad93c5a1..93d0593c92a2 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -92,8 +92,7 @@ class TORCH_API LoopNest { Stmt* root_stmt_; std::unordered_set output_tensors_; - std::unordered_set intermediate_tensors_; - std::vector temp_bufs_; + std::unordered_set intermediate_bufs_; // Holds the initializer Expr of buffers that have been initialized. std::unordered_map buf_initializers_; }; From 1036b77416bf0d46540ff36d5ca0911299ee74b7 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Thu, 8 Oct 2020 00:47:00 -0700 Subject: [PATCH 63/69] [TensorExpr] LoopNest: replace output_tensors_ with output_bufs_. (#45948) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45948 No functionality changes expected, it's just a preparation for further changes in the LoopNest interface. Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24156000 Pulled By: ZolotukhinM fbshipit-source-id: f95ab07aac0aba128bc4ed5376a3251ac9c31c06 --- torch/csrc/jit/tensorexpr/loopnest.cpp | 15 ++++++++------- torch/csrc/jit/tensorexpr/loopnest.h | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 39a45d427862..5c093707dc95 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -414,19 +414,22 @@ std::vector LoopNest::findAllNeededTensors( return result; } -LoopNest::LoopNest(const std::vector& output_tensors) - : output_tensors_(output_tensors.begin(), output_tensors.end()) { +LoopNest::LoopNest(const std::vector& output_tensors) { // Find all tensors we need to compute (including dependencies) and put them // in a topological order std::vector tensors_to_compute = findAllNeededTensors(output_tensors); + for (auto t : output_tensors) { + output_bufs_.insert(t->buf()); + } + // Find all intermediate tensors, we'll need that for inserting alloc/free // statements std::unordered_set tensors_to_compute_set( tensors_to_compute.begin(), tensors_to_compute.end()); for (Tensor* t : tensors_to_compute) { - if (!output_tensors_.count(t)) { + if (!output_bufs_.count(t->buf())) { intermediate_bufs_.insert(t->buf()); } } @@ -653,10 +656,8 @@ void LoopNest::computeInline(Stmt* s) { } void LoopNest::computeInline(const Buf* b) { - for (auto* t : output_tensors_) { - if (b == t->buf()) { - throw std::logic_error("Can't inline producers of output Tensors"); - } + if (output_bufs_.count(b)) { + throw std::logic_error("Can't inline producers of output Tensors"); } // Find producers. diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 93d0593c92a2..06a0691abdfd 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -91,7 +91,7 @@ class TORCH_API LoopNest { Stmt* root_stmt_; - std::unordered_set output_tensors_; + std::unordered_set output_bufs_; std::unordered_set intermediate_bufs_; // Holds the initializer Expr of buffers that have been initialized. 
std::unordered_map buf_initializers_; From 6e4de445010ef5e9e0438e4a4c16f6ef3129af14 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Thu, 8 Oct 2020 00:47:00 -0700 Subject: [PATCH 64/69] [TensorExpr] LoopNest: add a constructor that takes Stmt instead of list of Tensors. (#45949) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45949 Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D24156001 Pulled By: ZolotukhinM fbshipit-source-id: 6f4f050b04e802e274c42ed64be74c21ba79c29f --- torch/csrc/jit/tensorexpr/loopnest.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 06a0691abdfd..8eebf82b9886 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -24,7 +24,22 @@ class Dtype; class TORCH_API LoopNest { public: + // A constructor for building a LoopNest from a list of Tensors LoopNest(const std::vector& output_tensors); + + // A constructor for building a LoopNest from a pre-baked Stmt and meta-info + // TODO: Nuke intermediate_bufs_ and possibly buf_initializers from here if + // they can be deduced. + LoopNest( + Stmt* stmt, + const std::unordered_set& output_bufs, + const std::unordered_set& intermediate_bufs, + const std::unordered_map& buf_initializers) + : root_stmt_(stmt), + output_bufs_(output_bufs), + intermediate_bufs_(intermediate_bufs), + buf_initializers_(buf_initializers) {} + Stmt* root_stmt() const { return root_stmt_; } From 9dc9a55bc47c98d4683afb969d871615ca45874b Mon Sep 17 00:00:00 2001 From: Jonathan Conder Date: Thu, 8 Oct 2020 01:27:42 -0700 Subject: [PATCH 65/69] Fix TypeError when torch.jit.load is passed a pathlib.Path (#45825) Summary: Fixes https://github.com/pytorch/pytorch/issues/45824 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45825 Reviewed By: VitalyFedyunin Differential Revision: D24129441 Pulled By: gmagogsfm fbshipit-source-id: 52a76e39c163206cee2d19967e333e948adefe99 --- test/jit/test_save_load.py | 18 ++++++++++++++++++ torch/jit/_script.py | 4 ++-- torch/jit/_serialization.py | 2 +- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 9f11731d1864..178db8357e8f 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -1,5 +1,6 @@ import os import io +import pathlib import sys import random import torch @@ -920,3 +921,20 @@ def forward(self, a): with self.assertRaises(RuntimeError): extra_files['bar'] = '' torch.jit.load(buffer, _extra_files=extra_files) + + def test_save_load_using_pathlib(self): + class MyMod(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, a): + return 2 * a + + m = MyMod() + + # Save then load. + with TemporaryFileName() as fname: + path = pathlib.Path(fname) + m.save(path) + m2 = torch.jit.load(path) + + x = torch.tensor([1., 2., 3., 4.]) + self.assertTrue(torch.equal(m(x), m2(x))) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 0adbefc02cee..19cce3a86945 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -478,13 +478,13 @@ def code_with_constants(self): r = self.forward.code_with_constants return (r[0], ConstMap(r[1])) - def save(self, *args, **kwargs): + def save(self, f, **kwargs): r""" save(f, _extra_files={}) See :func:`torch.jit.save ` for details. 
""" - return self._c.save(*args, **kwargs) + return self._c.save(str(f), **kwargs) def _save_for_lite_interpreter(self, *args, **kwargs): r""" diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index d828ec8a0f1c..fd93cc13aeb6 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -158,7 +158,7 @@ def load(f, map_location=None, _extra_files=None): cu = torch._C.CompilationUnit() if isinstance(f, str) or isinstance(f, pathlib.Path): - cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files) + cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files) else: cpp_module = torch._C.import_ir_module_from_buffer( cu, f.read(), map_location, _extra_files From d93cae00f269735c52abd0fa0f41bdc823ba2cb0 Mon Sep 17 00:00:00 2001 From: Ivan Murashko Date: Thu, 8 Oct 2020 01:31:30 -0700 Subject: [PATCH 66/69] [HTE @ clang-tidy] Enable clang-tidy configs inheretence for caffe2 project Summary: The primary HTE configuration (for `HTE@clang-tidy` project) is stored at the parent config `~/fbsource/fbcode.clang-tidy`. The diff enables inheretence of that configuration. Note: `facebook-hte-` checks will not be used until switch to HTE2clang-tidy be made. Note: `clang-diagnostic-*` will start work. As result clang warning messages can be dublicated: one time from HTE and another time from clang-diagnostic Test Plan: N/A Reviewed By: wfarner Differential Revision: D24099167 fbshipit-source-id: 2e092fe678ad3e53a4cef301ce1cb737cf8401e7 --- .clang-tidy | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index a540d67a130e..e062760cf75c 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,6 +1,7 @@ --- # NOTE there must be no spaces before the '-', so put the comma last. -Checks: '-*, +InheritParentConfig: true +Checks: ' bugprone-*, -bugprone-forward-declaration-namespace, -bugprone-macro-parentheses, @@ -17,6 +18,7 @@ cppcoreguidelines-*, -cppcoreguidelines-pro-type-union-access, -cppcoreguidelines-pro-type-vararg, -cppcoreguidelines-special-member-functions, +-facebook-hte-RelativeInclude, hicpp-exception-baseclass, hicpp-avoid-goto, modernize-*, @@ -27,7 +29,7 @@ modernize-*, -modernize-use-trailing-return-type, performance-*, -performance-noexcept-move-constructor, - ' +' HeaderFilterRegex: 'torch/csrc/.*' AnalyzeTemporaryDtors: false CheckOptions: From 52f2db752d2b29267da356a06ca91e10cd732dbc Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 8 Oct 2020 02:13:00 -0700 Subject: [PATCH 67/69] unify reproducibility notes (#45748) Summary: Many of our functions contain same warnings about results reproducibility. Make them use common template. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45748 Reviewed By: colesbury Differential Revision: D24089114 Pulled By: ngimel fbshipit-source-id: e6aa4ce6082f6e0f4ce2713c2bf1864ee1c3712a --- torch/_tensor_docs.py | 22 +--- torch/_torch_docs.py | 22 ++-- torch/nn/functional.py | 106 ++++++----------- torch/nn/modules/conv.py | 250 ++++++++++----------------------------- 4 files changed, 119 insertions(+), 281 deletions(-) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index ffa4ffe6100a..fc27da90d3d8 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3,6 +3,7 @@ import torch._C from torch._C import _add_docstr as add_docstr from ._torch_docs import parse_kwargs +from ._torch_docs import reproducibility_notes def add_docstr_all(method, docstr): @@ -1628,12 +1629,7 @@ def add_docstr_all(method, docstr): match :attr:`self`, or an error will be raised. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {forward_reproducibility_note} Args: dim (int): dimension along which to index @@ -1651,7 +1647,7 @@ def add_docstr_all(method, docstr): [ 8., 9., 10.], [ 1., 1., 1.], [ 5., 6., 7.]]) -""") +""".format(**reproducibility_notes)) add_docstr_all('index_copy_', r""" @@ -2960,9 +2956,6 @@ def callable(a, b) -> number Reducing with the addition operation is the same as using :meth:`~torch.Tensor.scatter_add_`. -Note: - Reduction is not yet implemented for the CUDA backend. - Args: dim (int): the axis along which to index index (LongTensor): the indices of elements to scatter, @@ -3020,12 +3013,7 @@ def callable(a, b) -> number ``d != dim``. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {forward_reproducibility_note} Args: dim (int): the axis along which to index @@ -3045,7 +3033,7 @@ def callable(a, b) -> number [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]]) -""") +""".format(**reproducibility_notes)) add_docstr_all('select', r""" diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 609cd34b2e95..31580e4e0472 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -110,6 +110,19 @@ def merge_dicts(*dicts): "tf32_note": """This operator supports :ref:`TensorFloat32`.""" } + +reproducibility_notes = { + "forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "cudnn_reproducibility_note": """In some circumstances when given tensors on a CUDA device \ +and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. 
If this is \ +undesirable, you can try to make the operation deterministic (potentially at \ +a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. \ +See :doc:`/notes/randomness` for more information.""" +} + add_docstr(torch.abs, r""" abs(input, *, out=None) -> Tensor @@ -938,12 +951,7 @@ def merge_dicts(*dicts): ``out[n] += 1``. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Arguments: input (Tensor): 1-d int tensor @@ -968,7 +976,7 @@ def merge_dicts(*dicts): >>> input.bincount(weights) tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) -""") +""".format(**reproducibility_notes)) add_docstr(torch.bitwise_not, r""" diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 72d55d30ad6d..d1575c14323a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -11,6 +11,7 @@ from torch import _VF from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple from ..overrides import has_torch_function, handle_torch_function +from torch._torch_docs import reproducibility_notes, tf32_notes Tensor = torch.Tensor @@ -21,17 +22,13 @@ Applies a 1D convolution over an input signal composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv1d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -59,19 +56,13 @@ Applies a 2D convolution over an input image composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv2d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - - + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` @@ -99,17 +90,13 @@ Applies a 3D convolution over an input image composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv3d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. 
If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -137,17 +124,13 @@ Applies a 1D transposed convolution operator over an input signal composed of several input planes, sometimes also called "deconvolution". -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose1d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -178,17 +161,13 @@ Applies a 2D transposed convolution operator over an input image composed of several input planes, sometimes also called "deconvolution". -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose2d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` @@ -221,17 +200,13 @@ Applies a 3D transposed convolution operator over an input image composed of several input planes, sometimes also called "deconvolution" -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose3d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -1831,6 +1806,7 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., [ 0.0000, 0.0000, 0.0000], [ 0.6262, 0.2438, 0.7471]]]) """ + if padding_idx is not None: if padding_idx > 0: assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings' @@ -1862,9 +1838,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, See :class:`torch.nn.EmbeddingBag` for more details. 
Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Args: input (LongTensor): Tensor containing bags of indices into the embedding matrix @@ -1932,6 +1906,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, tensor([[ 0.3397, 0.3552, 0.5545], [ 0.5893, 0.4386, 0.5882]]) """ + if not torch.jit.is_scripting(): tens_ops = (input, weight) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): @@ -2018,6 +1993,8 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, include_last_offset) return ret +embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) + def _verify_batch_size(size): # type: (List[int]) -> None @@ -2152,17 +2129,10 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, See :class:`~torch.nn.CTCLoss` for details. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Args: log_probs: :math:`(T, N, C)` where `C = number of characters in alphabet including blank`, @@ -2199,6 +2169,7 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, """ return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity) +ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, @@ -2901,9 +2872,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= This is equivalent with ``nn.functional.interpolate(...)``. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} The algorithm used for upsampling is determined by :attr:`mode`. @@ -2953,6 +2922,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= """ warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode, align_corners) +upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 @@ -3042,9 +3012,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne calculation. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. 
+ {backward_reproducibility_note} """ if not torch.jit.is_scripting(): if type(input) is not Tensor and has_torch_function((input,)): @@ -3174,6 +3142,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported" " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" " (got {})".format(input.dim(), mode)) +interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 @@ -3202,13 +3171,12 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 scale_factor (int): multiplier for spatial size. Has to be an integer. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode='nearest') +upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) @_overload # noqa: F811 def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 @@ -3247,14 +3215,12 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 scale_factor (int or Tuple[int, int]): multiplier for spatial size Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True) - +upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) GRID_SAMPLE_INTERPOLATION_MODES = { 'bilinear': 0, @@ -3801,7 +3767,7 @@ def assert_int_or_pair(arg, arg_name, message): def unfold(input, kernel_size, dilation=1, padding=0, stride=1): # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa - r"""Extracts sliding local blocks from an batched input tensor. + r"""Extracts sliding local blocks from a batched input tensor. .. warning:: Currently, only 4-D input tensors (batched image-like tensors) are diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 3b9391d1061c..7280eab37caa 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -9,10 +9,35 @@ from .. import init from .module import Module from .utils import _single, _pair, _triple, _reverse_repeat_tuple +from torch._torch_docs import reproducibility_notes from ..common_types import _size_1_t, _size_2_t, _size_3_t from typing import Optional, List, Tuple +convolution_notes = \ + {"groups_note": """* :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. 
+ * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\\frac{\\text{out\_channels}}{\\text{in\_channels}}`).""", # noqa: W605 + + "depthwise_separable_note": """When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\\text{in}=C_\\text{in}, C_\\text{out}=C_\\text{in} \\times \\text{K}, ..., \\text{groups}=C_\\text{in})`."""} # noqa: W605 + + + + class _ConvNd(Module): @@ -113,7 +138,7 @@ def __setstate__(self, state): class Conv1d(_ConvNd): - r"""Applies a 1D convolution over an input signal composed of several input + __doc__ = r"""Applies a 1D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -128,58 +153,26 @@ class Conv1d(_ConvNd): where :math:`\star` is the valid `cross-correlation`_ operator, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`L` is a length of signal sequence. + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation, a single number or a one-element tuple. - * :attr:`padding` controls the amount of implicit zero-paddings on both sides + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, - of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + {groups_note} Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid - `cross-correlation`_, and not a full `cross-correlation`_. - It is up to the user to add proper padding. - + {depthwise_separable_note} Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, L_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`. - - Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. 
If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -197,6 +190,8 @@ class Conv1d(_ConvNd): bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" + Shape: - Input: :math:`(N, C_{in}, L_{in})` - Output: :math:`(N, C_{out}, L_{out})` where @@ -260,7 +255,7 @@ def forward(self, input: Tensor) -> Tensor: class Conv2d(_ConvNd): - r"""Applies a 2D convolution over an input signal composed of several input + __doc__ = r"""Applies a 2D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -276,31 +271,21 @@ class Conv2d(_ConvNd): :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation, a single number or a tuple. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points for each dimension. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, of size: - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: @@ -309,30 +294,10 @@ class Conv2d(_ConvNd): and the second `int` for the width dimension Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - - Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. + {depthwise_separable_note} Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. 
If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -348,6 +313,7 @@ class Conv2d(_ConvNd): channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` @@ -391,6 +357,7 @@ class Conv2d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, @@ -423,7 +390,7 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight) class Conv3d(_ConvNd): - r"""Applies a 3D convolution over an input signal composed of several input + __doc__ = r"""Applies a 3D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` @@ -434,29 +401,19 @@ class Conv3d(_ConvNd): \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) where :math:`\star` is the valid 3D `cross-correlation`_ operator + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points for each dimension. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: @@ -465,30 +422,10 @@ class Conv3d(_ConvNd): the second `int` for the height dimension and the third `int` for the width dimension Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. + {depthwise_separable_note} Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. 
- - Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -500,6 +437,7 @@ class Conv3d(_ConvNd): dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` @@ -544,6 +482,7 @@ class Conv3d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, @@ -628,7 +567,7 @@ def _output_padding(self, input, output_size, stride, padding, kernel_size, dila class ConvTranspose1d(_ConvTransposeNd): - r"""Applies a 1D transposed convolution operator over an input image + __doc__ = r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv1d with respect to its input. @@ -639,7 +578,7 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -649,25 +588,7 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). - - Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. + {groups_note} Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` @@ -702,6 +623,7 @@ class ConvTranspose1d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. 
Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, L_{in})` @@ -764,7 +686,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten class ConvTranspose2d(_ConvTransposeNd): - r"""Applies a 2D transposed convolution operator over an input image + __doc__ = r"""Applies a 2D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv2d with respect to its input. @@ -775,7 +697,7 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -785,18 +707,7 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` can either be: @@ -805,13 +716,6 @@ class ConvTranspose2d(_ConvTransposeNd): - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension - .. note:: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sizes of the input. This is set so that @@ -825,13 +729,7 @@ class ConvTranspose2d(_ConvTransposeNd): not actually add zero-padding to output. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -845,6 +743,7 @@ class ConvTranspose2d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. 
Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` @@ -930,7 +829,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten class ConvTranspose3d(_ConvTransposeNd): - r"""Applies a 3D transposed convolution operator over an input image composed of several input + __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes. @@ -943,7 +842,7 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -953,18 +852,7 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` can either be: @@ -973,13 +861,6 @@ class ConvTranspose3d(_ConvTransposeNd): - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the height dimension and the third `int` for the width dimension - Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sizes of the input. This is set so that @@ -993,13 +874,7 @@ class ConvTranspose3d(_ConvTransposeNd): not actually add zero-padding to output. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -1013,6 +888,7 @@ class ConvTranspose3d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. 
Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` From acca11b89808c8d57e914057a6c0e60cb66cea46 Mon Sep 17 00:00:00 2001 From: Taras Galkovskyi Date: Thu, 8 Oct 2020 06:13:25 -0700 Subject: [PATCH 68/69] [torchscript] Verbose logging of code location causing the error (#45908) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45908 As per subj, existing logging does not explain the cause of the error Test Plan: unit tests pass. Reviewed By: SplitInfinity Differential Revision: D23609965 fbshipit-source-id: 818965176f7193c62035e3d2f0547bb525fea0fb --- torch/jit/frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 36dccd04b7e3..5949a38d4cc7 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -228,7 +228,7 @@ def _forward(self): dedent_src = dedent(source) py_ast = ast.parse(dedent_src) if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): - raise RuntimeError("Expected a single top-level function") + raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) type_line = torch.jit.annotations.get_type_line(source) ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) @@ -238,7 +238,7 @@ def _forward(self): if should_drop(fn): unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")") if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef): - raise RuntimeError("Expected a single top-level function") + raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") unused_def = unused_fn_def.body[0] fn_def.body = unused_def.body # kwarg/vararg not supported by `build_def` From 7d4f5060ade84bf79bec2d3827b42fa3c3737123 Mon Sep 17 00:00:00 2001 From: Shijun Kong Date: Thu, 8 Oct 2020 09:11:25 -0700 Subject: [PATCH 69/69] Fix doc about operator benchmark (#45853) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45853 The method name in README is not consistent with actual implementation. Reviewed By: qizzzh Differential Revision: D24114849 fbshipit-source-id: d979e324c768708e99b8cc5b87e261f17c22a883 --- benchmarks/operator_benchmark/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 95e0e46bf79a..9cdd46a4ea21 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -344,7 +344,7 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):     def forward(self):         return self.op_func(self.input_one) -op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) +op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) if __name__ == "__main__":     op_bench.benchmark_runner.main() @@ -388,10 +388,10 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): ``` #### Part 3. Register a List of Operators -To register multiple operators, we introduced the `generate_pt_tests_from_list` function which takes three parameters. First, the list of operators. Second,the configs. Third, the benchmark class.   
+To register multiple operators, we introduced the `generate_pt_tests_from_op_list` function which takes three parameters. First, the list of operators. Second, the configs. Third, the benchmark class.   
 Here is an example:
 ```
-op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark)
+op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark)
 ```
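The three parts above typically live together in a single benchmark file. A minimal end-to-end sketch following the unary-ops example (the operator choices and config values here are illustrative only):

```
import operator_benchmark as op_bench
import torch

# Part 1: the list of operators to benchmark (illustrative choices).
unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["abs", torch.abs],
        ["neg", torch.neg],
    ],
)

# Part 2: the input configurations (illustrative shapes).
unary_ops_configs = op_bench.cross_product_configs(
    M=[256, 512],
    N=[256, 512],
    device=["cpu"],
    tags=["short"],
)

# Part 3: the benchmark class; op_func is injected from the op list.
class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device, op_func):
        self.input_one = torch.rand(M, N, device=device)
        self.op_func = op_func

    def forward(self):
        return self.op_func(self.input_one)

op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
```

As with the existing benchmark scripts, running the file directly invokes `benchmark_runner.main()` to execute every registered test.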