From 0650a6166ff42a65b431f527d0f9a76f5be44e37 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 10 Nov 2020 23:27:21 -0800 Subject: [PATCH 01/93] [c10d] switch ProcessGroup::Work to be managed by intrusive_ptr (#44046) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44046 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D23632280 Pulled By: wanchaol fbshipit-source-id: 0a4642a8ffabdd26c52c1baabfa30c0f446c3c85 --- test/cpp_extensions/cpp_c10d_extension.cpp | 32 +++--- test/cpp_extensions/cpp_c10d_extension.hpp | 26 ++--- torch/csrc/distributed/c10d/init.cpp | 5 +- .../distributed/rpc/process_group_agent.cpp | 2 +- .../distributed/rpc/process_group_agent.h | 4 +- torch/lib/c10d/ProcessGroup.cpp | 2 +- torch/lib/c10d/ProcessGroup.hpp | 35 +++--- torch/lib/c10d/ProcessGroupGloo.cpp | 101 +++++++++--------- torch/lib/c10d/ProcessGroupGloo.hpp | 38 +++---- torch/lib/c10d/ProcessGroupMPI.cpp | 42 ++++---- torch/lib/c10d/ProcessGroupMPI.hpp | 36 +++---- torch/lib/c10d/ProcessGroupNCCL.cpp | 52 ++++----- torch/lib/c10d/ProcessGroupNCCL.hpp | 46 ++++---- torch/lib/c10d/ProcessGroupRoundRobin.cpp | 30 +++--- torch/lib/c10d/ProcessGroupRoundRobin.hpp | 30 +++--- torch/lib/c10d/comm.cpp | 4 +- torch/lib/c10d/example/allreduce.cpp | 2 +- torch/lib/c10d/reducer.cpp | 2 +- torch/lib/c10d/reducer.hpp | 9 +- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 10 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 16 +-- torch/lib/c10d/test/ProcessGroupMPITest.cpp | 38 +++---- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 12 +-- 24 files changed, 295 insertions(+), 287 deletions(-) diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index b4901cdbcf4d..50e5f5861caa 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -23,85 +23,85 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size) ProcessGroupTest::~ProcessGroupTest() {} -std::shared_ptr ProcessGroupTest::broadcast( +c10::intrusive_ptr ProcessGroupTest::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce( +c10::intrusive_ptr ProcessGroupTest::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupTest::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); } -std::shared_ptr ProcessGroupTest::reduce( +c10::intrusive_ptr ProcessGroupTest::reduce( std::vector& tensors, const ReduceOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce"); } -std::shared_ptr ProcessGroupTest::allgather( +c10::intrusive_ptr ProcessGroupTest::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather"); } -std::shared_ptr ProcessGroupTest::allgather_base( +c10::intrusive_ptr ProcessGroupTest::allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather_base"); } -std::shared_ptr ProcessGroupTest::barrier( +c10::intrusive_ptr ProcessGroupTest::barrier( const BarrierOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::gather( +c10::intrusive_ptr ProcessGroupTest::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support gather"); } -std::shared_ptr ProcessGroupTest::scatter( +c10::intrusive_ptr ProcessGroupTest::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support scatter"); } -std::shared_ptr ProcessGroupTest::reduce_scatter( +c10::intrusive_ptr ProcessGroupTest::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); } -std::shared_ptr ProcessGroupTest::send( +c10::intrusive_ptr ProcessGroupTest::send( std::vector& tensors, int dstRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support send"); } -std::shared_ptr ProcessGroupTest::recv( +c10::intrusive_ptr ProcessGroupTest::recv( std::vector& tensors, int srcRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support recv"); } -std::shared_ptr ProcessGroupTest::recvAnysource( +c10::intrusive_ptr ProcessGroupTest::recvAnysource( std::vector& tensor, int tag) { throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index d8dffcd20327..8aeec736d440 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -41,61 +41,61 @@ class ProcessGroupTest : public ProcessGroup { explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index d9ddf35ee1df..136efd32fc87 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,5 +1,6 @@ #include +#include #include #ifndef _WIN32 #include @@ -59,6 +60,8 @@ constexpr auto kDeprecationWarning = "{} API is being deprecated, please ping " "https://github.com/pytorch/pytorch/issues/46291 " "if you see this warning"; +template +using intrusive_ptr_class_ = py::class_>; // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. @@ -1045,7 +1048,7 @@ that adds a prefix to each key inserted to the store. py::call_guard()); #endif - shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") + intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) .def( "is_success", diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 2f29adc8f0c4..13e685b8fe74 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -398,7 +398,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) { // ProcessGroup is not thread-safe when sending with the same tag, // hence the lock - std::vector> pendingSends; + std::vector> pendingSends; const auto dst = work.to_.id_; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 1bc8db9ebf20..70fb1b40244d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // Lock and shared ptr to currently pending work, set in listenloop() and // interruptible in shutdown(). std::mutex recvWorkMutex_; - std::shared_ptr recvWork_; + c10::intrusive_ptr recvWork_; // Map of dst rank to current oustanding sends that we are waiting on. In the // case of a call to ::shutdown() while we are still waiting on these sends, // the pending sends contained in this map will be aborted, allowing the // waiting thread to be unblocked. std::unordered_map< worker_id_t, - std::set>> + std::set>> currentPendingSends_; // Lock to serialize access to the above map. std::mutex pendingSendMutex_; diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 3521ed42c840..1d0d451f21a9 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -164,7 +164,7 @@ ProcessGroup::~ProcessGroup() {} // This is introduced so that implementors of ProcessGroup would not need to // have this implmentation. -std::shared_ptr ProcessGroup::allgather_coalesced( +c10::intrusive_ptr ProcessGroup::allgather_coalesced( std::vector>& /* usused */, std::vector& /* usused */, const AllgatherOptions& /* usused */) { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 5e90dccc25c0..63996b516a06 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -70,12 +70,11 @@ bool isP2POp(OpType opType); // class ProcessGroup { public: - // Please do not use ProcessGroup::Work API, it is going away, to be // replaced by ivalue::Future. // Python binding for this class might change, please do not assume // this will be bound using pybind. - class Work { + class Work : public torch::CustomClassHolder { public: Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); @@ -171,25 +170,25 @@ class ProcessGroup { return size_; } - virtual std::shared_ptr broadcast( + virtual c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) = 0; - virtual std::shared_ptr allreduce( + virtual c10::intrusive_ptr allreduce( std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; // This will be moved out of ProcessGroup, do not add dependencies on this // function. - virtual std::shared_ptr allreduce_coalesced( + virtual c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; - virtual std::shared_ptr reduce( + virtual c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) = 0; - virtual std::shared_ptr allgather( + virtual c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -197,7 +196,7 @@ class ProcessGroup { // Gathers a single tensor inputBuffer into a single buffer outputBuffer that // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. - virtual std::shared_ptr allgather_base( + virtual c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -206,27 +205,27 @@ class ProcessGroup { // * do not add dependencies on this function, // * do not implement it in your ProcessGroup, implement allgather_base // instead. - virtual std::shared_ptr allgather_coalesced( + virtual c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()); - virtual std::shared_ptr gather( + virtual c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) = 0; - virtual std::shared_ptr scatter( + virtual c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) = 0; - virtual std::shared_ptr reduce_scatter( + virtual c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; - virtual std::shared_ptr alltoall_base( + virtual c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -235,28 +234,28 @@ class ProcessGroup { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr alltoall( + virtual c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr send( + virtual c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) = 0; - virtual std::shared_ptr recv( + virtual c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) = 0; - virtual std::shared_ptr recvAnysource( + virtual c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) = 0; - virtual std::shared_ptr barrier( + virtual c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) = 0; protected: diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index cd3e83e6b714..90c9b695de28 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -38,6 +38,7 @@ #endif #include +#include #include #include #include @@ -653,11 +654,11 @@ void ProcessGroupGloo::runLoop(int workerIndex) { AsyncWork::execute(std::move(work)); lock.lock(); - workInProgress_[workerIndex] = nullptr; + workInProgress_[workerIndex].reset(); } } -void ProcessGroupGloo::enqueue(std::shared_ptr work) { +void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -773,7 +774,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { } // namespace -std::shared_ptr ProcessGroupGloo::broadcast( +c10::intrusive_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -796,15 +797,15 @@ std::shared_ptr ProcessGroupGloo::broadcast( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { @@ -1300,7 +1301,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::allreduce( +c10::intrusive_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1329,15 +1330,15 @@ std::shared_ptr ProcessGroupGloo::allreduce( "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1345,10 +1346,10 @@ std::shared_ptr ProcessGroupGloo::allreduce( #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1362,7 +1363,7 @@ std::shared_ptr ProcessGroupGloo::allreduce( return work; } -std::shared_ptr ProcessGroupGloo::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1405,12 +1406,12 @@ std::shared_ptr ProcessGroupGloo::allreduce_coalesced( invalidArgument("unsupported layout"); } - std::shared_ptr work; + c10::intrusive_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); @@ -1538,7 +1539,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::reduce( +c10::intrusive_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1561,11 +1562,11 @@ std::shared_ptr ProcessGroupGloo::reduce( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1574,7 +1575,7 @@ std::shared_ptr ProcessGroupGloo::reduce( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1720,7 +1721,7 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. -std::shared_ptr ProcessGroupGloo::allgather( +c10::intrusive_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { @@ -1769,15 +1770,15 @@ std::shared_ptr ProcessGroupGloo::allgather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #endif } else { @@ -1852,7 +1853,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::allgather_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { @@ -1902,13 +1903,13 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } -std::shared_ptr ProcessGroupGloo::allgather_base( +c10::intrusive_ptr ProcessGroupGloo::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { @@ -2057,7 +2058,7 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { } // namespace -std::shared_ptr ProcessGroupGloo::gather( +c10::intrusive_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { @@ -2103,15 +2104,15 @@ std::shared_ptr ProcessGroupGloo::gather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2245,7 +2246,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { } // namespace -std::shared_ptr ProcessGroupGloo::scatter( +c10::intrusive_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { @@ -2290,15 +2291,15 @@ std::shared_ptr ProcessGroupGloo::scatter( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2308,7 +2309,7 @@ std::shared_ptr ProcessGroupGloo::scatter( return work; } -std::shared_ptr ProcessGroupGloo::reduce_scatter( +c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { @@ -2443,7 +2444,7 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { } // namespace -std::shared_ptr ProcessGroupGloo::alltoall_base( +c10::intrusive_ptr ProcessGroupGloo::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2460,12 +2461,12 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( assertDense(invalidArgument, {inputTensor}); const auto& device = outputTensor.device(); - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2474,7 +2475,7 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2510,7 +2511,7 @@ uint32_t checkTag(int32_t tag) { return (uint32_t)tag; } -std::shared_ptr ProcessGroupGloo::send( +c10::intrusive_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { @@ -2526,10 +2527,10 @@ std::shared_ptr ProcessGroupGloo::send( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the send. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recv( +c10::intrusive_ptr ProcessGroupGloo::recv( std::vector& tensors, int srcRank, int tag) { @@ -2545,10 +2546,10 @@ std::shared_ptr ProcessGroupGloo::recv( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recvAnysource( +c10::intrusive_ptr ProcessGroupGloo::recvAnysource( std::vector& tensors, int tag) { auto& tensor = checkSingleTensor(tensors); @@ -2573,7 +2574,7 @@ std::shared_ptr ProcessGroupGloo::recvAnysource( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } namespace { @@ -2582,13 +2583,13 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( const std::shared_ptr& context, - std::vector> priorWork, + std::vector> priorWork, uint32_t tag) : ProcessGroupGloo::AsyncWork("gloo:barrier"), context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; - std::vector> priorWork; + std::vector> priorWork; const uint32_t tag; void run() override { @@ -2608,9 +2609,9 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::barrier( +c10::intrusive_ptr ProcessGroupGloo::barrier( const BarrierOptions& opts) { - std::vector> priorWork; + std::vector> priorWork; // Snapshot all in progress and pending work as weak_ptr. // When executing a barrier, we need to ensure that all prior work @@ -2624,7 +2625,7 @@ std::shared_ptr ProcessGroupGloo::barrier( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), std::move(priorWork), tag); enqueue(work); return work; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 31664ad0b6cf..74fd0f6e5165 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -70,7 +70,7 @@ class ProcessGroupGloo : public ProcessGroup { public: AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} - static void execute(std::shared_ptr work) { + static void execute(c10::intrusive_ptr work) { std::exception_ptr eptr; try { work->run(); @@ -159,75 +159,75 @@ class ProcessGroupGloo : public ProcessGroup { virtual ~ProcessGroupGloo(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; protected: @@ -258,7 +258,7 @@ class ProcessGroupGloo : public ProcessGroup { void runLoop(int workerIndex); // Queue work to run on worker thread. - void enqueue(std::shared_ptr work); + void enqueue(c10::intrusive_ptr work); // Keep both a queue of pending work, and a vector with in progress work. // Both of these can only be mutated when holding the queue lock. @@ -266,8 +266,8 @@ class ProcessGroupGloo : public ProcessGroup { // to all in progress and pending work when executing a barrier. // When executing a barrier, we need to ensure that all prior work // has completed before completing itself. - std::deque> workQueue_; - std::vector> workInProgress_; + std::deque> workQueue_; + std::vector> workInProgress_; std::mutex workMutex_; std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index d3e79a1dd424..5f9d0be41b8f 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -308,9 +308,9 @@ void ProcessGroupMPI::runLoop() { } } -std::shared_ptr ProcessGroupMPI::enqueue( +c10::intrusive_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry) { - auto work = std::make_shared(); + auto work = c10::make_intrusive(); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -318,7 +318,7 @@ std::shared_ptr ProcessGroupMPI::enqueue( return work; } -std::shared_ptr ProcessGroupMPI::broadcast( +c10::intrusive_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { checkSingleTensor(tensors); @@ -339,7 +339,7 @@ std::shared_ptr ProcessGroupMPI::broadcast( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce( +c10::intrusive_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { checkSingleTensor(tensors); @@ -362,14 +362,14 @@ std::shared_ptr ProcessGroupMPI::allreduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with MPI"); } -std::shared_ptr ProcessGroupMPI::reduce( +c10::intrusive_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { checkSingleTensor(tensors); @@ -397,7 +397,7 @@ std::shared_ptr ProcessGroupMPI::reduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather( +c10::intrusive_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -441,7 +441,7 @@ std::shared_ptr ProcessGroupMPI::allgather( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -449,7 +449,7 @@ std::shared_ptr ProcessGroupMPI::allgather_coalesced( "ProcessGroupMPI does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupMPI::gather( +c10::intrusive_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { @@ -516,7 +516,7 @@ std::shared_ptr ProcessGroupMPI::gather( } } -std::shared_ptr ProcessGroupMPI::scatter( +c10::intrusive_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { @@ -582,14 +582,14 @@ std::shared_ptr ProcessGroupMPI::scatter( } } -std::shared_ptr ProcessGroupMPI::reduce_scatter( +c10::intrusive_ptr ProcessGroupMPI::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); } -std::shared_ptr ProcessGroupMPI::alltoall_base( +c10::intrusive_ptr ProcessGroupMPI::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -665,7 +665,7 @@ std::shared_ptr ProcessGroupMPI::alltoall_base( return enqueue(std::move(entry)); } } -std::shared_ptr ProcessGroupMPI::alltoall( +c10::intrusive_ptr ProcessGroupMPI::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts) { @@ -722,7 +722,7 @@ std::shared_ptr ProcessGroupMPI::alltoall( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::send( +c10::intrusive_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { @@ -744,10 +744,10 @@ std::shared_ptr ProcessGroupMPI::send( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recv( +c10::intrusive_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { @@ -769,10 +769,10 @@ std::shared_ptr ProcessGroupMPI::recv( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recvAnysource( +c10::intrusive_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { checkSingleTensor(tensors); @@ -793,10 +793,10 @@ std::shared_ptr ProcessGroupMPI::recvAnysource( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::barrier( +c10::intrusive_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { std::function&)> runFunc = [this](std::unique_ptr& entry) { @@ -808,7 +808,7 @@ std::shared_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_base( +c10::intrusive_ptr ProcessGroupMPI::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 342fe87001a0..48d95eada887 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -108,80 +108,80 @@ class ProcessGroupMPI : public ProcessGroup { // Abort the MPI program, needs to be called when exception is detected void abort(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized @@ -190,13 +190,13 @@ class ProcessGroupMPI : public ProcessGroup { protected: using WorkType = - std::tuple, std::shared_ptr>; + std::tuple, c10::intrusive_ptr>; // Worker thread loop void runLoop(); // Helper function that is called by the destructor void destroy(); - std::shared_ptr enqueue(std::unique_ptr entry); + c10::intrusive_ptr enqueue(std::unique_ptr entry); bool stop_; diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index ba0b4b36c77d..bd1563226343 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -984,12 +984,12 @@ std::vector flatten_for_scatter_gather( } // namespace -std::shared_ptr ProcessGroupNCCL::initWork( +c10::intrusive_ptr ProcessGroupNCCL::initWork( std::vector devices, int rank, OpType opType, const char* profilingTitle) { - return std::make_shared(devices, rank, opType, profilingTitle); + return c10::make_intrusive(devices, rank, opType); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1012,7 +1012,7 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: } void ProcessGroupNCCL::workEnqueue( - std::shared_ptr work) { + c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. @@ -1027,7 +1027,7 @@ ProcessGroupNCCL::Options::Options() isHighPriorityStream(false) {} template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1114,7 +1114,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, @@ -1186,7 +1186,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( } template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1203,7 +1203,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -1217,7 +1217,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( [](std::vector&) {}); } -std::shared_ptr ProcessGroupNCCL::allreduce( +c10::intrusive_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_gpu_tensors(tensors); @@ -1242,14 +1242,14 @@ std::shared_ptr ProcessGroupNCCL::allreduce( "nccl:all_reduce"); } -std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with NCCL"); } -std::shared_ptr ProcessGroupNCCL::broadcast( +c10::intrusive_ptr ProcessGroupNCCL::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_gpu_tensors(tensors); @@ -1274,7 +1274,7 @@ std::shared_ptr ProcessGroupNCCL::broadcast( "nccl:broadcast"); } -std::shared_ptr ProcessGroupNCCL::reduce( +c10::intrusive_ptr ProcessGroupNCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { check_gpu_tensors(tensors); @@ -1301,7 +1301,7 @@ std::shared_ptr ProcessGroupNCCL::reduce( "nccl:reduce"); } -std::shared_ptr ProcessGroupNCCL::allgather( +c10::intrusive_ptr ProcessGroupNCCL::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -1346,7 +1346,7 @@ std::shared_ptr ProcessGroupNCCL::allgather( "nccl:all_gather"); } -std::shared_ptr ProcessGroupNCCL::allgather_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -1354,7 +1354,7 @@ std::shared_ptr ProcessGroupNCCL::allgather_coalesced( "ProcessGroupNCCL does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupNCCL::reduce_scatter( +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { @@ -1400,7 +1400,7 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( "nccl:reduce_scatter"); } -std::shared_ptr ProcessGroupNCCL::barrier( +c10::intrusive_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { std::vector devices; if (usedDeviceIdxs_.empty()) { @@ -1441,7 +1441,7 @@ std::shared_ptr ProcessGroupNCCL::barrier( } #ifdef ENABLE_NCCL_P2P_SUPPORT -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -1512,7 +1512,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( } } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& tensors, int dstRank, int /* unused */) { @@ -1531,7 +1531,7 @@ std::shared_ptr ProcessGroupNCCL::send( return ret; } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& tensors, int srcRank, int /* unused */) { @@ -1550,7 +1550,7 @@ std::shared_ptr ProcessGroupNCCL::recv( return ret; } #else -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, at::Tensor& /* unused */, std::vector& /* unused */, @@ -1560,7 +1560,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1568,7 +1568,7 @@ std::shared_ptr ProcessGroupNCCL::send( "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1591,34 +1591,34 @@ void ProcessGroupNCCL::groupEnd() { --ncclActiveGroupCounter_; } -std::shared_ptr ProcessGroupNCCL::alltoall( +c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); } -std::shared_ptr ProcessGroupNCCL::gather( +c10::intrusive_ptr ProcessGroupNCCL::gather( std::vector>& /* unused */, std::vector& /* unused */, const GatherOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support gather"); } -std::shared_ptr ProcessGroupNCCL::scatter( +c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector& /* unused */, std::vector>& /* unused */, const ScatterOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -std::shared_ptr ProcessGroupNCCL::recvAnysource( +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } -std::shared_ptr ProcessGroupNCCL::allgather_base( +c10::intrusive_ptr ProcessGroupNCCL::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 1520604629f2..59f06fda1ec1 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -65,7 +65,7 @@ constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; class ProcessGroupNCCL : public ProcessGroup { public: class WorkNCCL : public ProcessGroup::Work, - public std::enable_shared_from_this { + public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); @@ -411,64 +411,64 @@ class ProcessGroupNCCL : public ProcessGroup { virtual ~ProcessGroupNCCL(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; @@ -478,17 +478,17 @@ class ProcessGroupNCCL : public ProcessGroup { static void groupEnd(); // Unsupported Ops - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -515,7 +515,7 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); - virtual std::shared_ptr initWork( + virtual c10::intrusive_ptr initWork( std::vector devices, int rank, OpType opType, @@ -529,14 +529,14 @@ class ProcessGroupNCCL : public ProcessGroup { // ncclComm_t, at::cuda::CUDAStream&); // void {pre,post}(std::vector); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, OpType opType, const char* profilingTitle = nullptr); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, @@ -549,13 +549,13 @@ class ProcessGroupNCCL : public ProcessGroup { // primitives. It is the same structure as the helper used for collective // communicaiton primitives. template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, OpType opType); template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -664,7 +664,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::list workMetaList_; // Add Work Pointer to workVector - void workEnqueue(std::shared_ptr); + void workEnqueue(c10::intrusive_ptr); // The CUDA steams used by NCCL kernels std::unordered_map> diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index 032f63c320f5..c77188577a62 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -17,66 +17,66 @@ ProcessGroupRoundRobin::ProcessGroupRoundRobin( ProcessGroupRoundRobin::~ProcessGroupRoundRobin() {} -std::shared_ptr ProcessGroupRoundRobin::broadcast( +c10::intrusive_ptr ProcessGroupRoundRobin::broadcast( std::vector& tensors, const BroadcastOptions& opts) { return next()->broadcast(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( std::vector& tensors, const AllreduceOptions& opts) { return next()->allreduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::reduce( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce( std::vector& tensors, const ReduceOptions& opts) { return next()->reduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allgather( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { return next()->allgather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::allgather_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::gather( +c10::intrusive_ptr ProcessGroupRoundRobin::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { return next()->gather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { return next()->scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::reduce_scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { return next()->reduce_scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::alltoall_base( +c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -86,27 +86,27 @@ std::shared_ptr ProcessGroupRoundRobin::alltoall_base( outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); }; -std::shared_ptr ProcessGroupRoundRobin::send( +c10::intrusive_ptr ProcessGroupRoundRobin::send( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support send"); }; -std::shared_ptr ProcessGroupRoundRobin::recv( +c10::intrusive_ptr ProcessGroupRoundRobin::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::recvAnysource( +c10::intrusive_ptr ProcessGroupRoundRobin::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::barrier( +c10::intrusive_ptr ProcessGroupRoundRobin::barrier( const BarrierOptions& /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; @@ -120,7 +120,7 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } -std::shared_ptr ProcessGroupRoundRobin::allgather_base( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index bbbd0a1c756b..62d59ef18ce5 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -25,75 +25,75 @@ class ProcessGroupRoundRobin final : public ProcessGroup { ~ProcessGroupRoundRobin() override; - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: diff --git a/torch/lib/c10d/comm.cpp b/torch/lib/c10d/comm.cpp index a8628e0c942e..5ef88f058aca 100644 --- a/torch/lib/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -45,8 +45,10 @@ class BroadcastWork { // because c10d::ProcessGroup::broadcast takes a vector argument. std::vector flat_tensor_; + private: + // The broadcast work that is kicked off upon construction. - std::shared_ptr work_; + c10::intrusive_ptr work_; }; } // namespace diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 76d6a5588f7e..3de7447d092a 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -19,7 +19,7 @@ int main(int argc, char** argv) { } // Kick off work - std::vector> pending; + std::vector> pending; for (auto i = 0; i < ntensors; i++) { std::vector tmp = {tensors[i]}; pending.push_back(pg.allreduce(tmp)); diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index c05ce685bb7d..c5ee54a9ee8e 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -472,7 +472,7 @@ std::vector> Reducer::get_bucket_tensors() const { } void Reducer::set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize) { std::lock_guard lock(mutex_); forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index 4874f0dd8703..e0fe0004f88e 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -96,7 +97,7 @@ class Reducer { // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the // corresponding tensor being reduced. void set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize); // Retrieve on-device tensors used to track locally unused parameters. For @@ -158,7 +159,7 @@ class Reducer { bool local_used_maps_reduced_; // Work handle for allreduce on local_used_maps_ - std::shared_ptr local_used_work_; + c10::intrusive_ptr local_used_work_; void verify_replicas_within_process(); @@ -282,7 +283,7 @@ class Reducer { size_t pending; // Keep work handle around when this set of buckets is being reduced. - std::shared_ptr work; + c10::intrusive_ptr work; // Keep future work handle around if DDP comm hook is registered. c10::intrusive_ptr future_work; @@ -340,7 +341,7 @@ class Reducer { // A struct containing work handle and tensor for allreduce scheduled in // forward pass, if applicable. struct ForwardPassAllreduceWork { - std::shared_ptr workHandle; + c10::intrusive_ptr workHandle; at::Tensor resultTensor; // whether we should divide by the initial world_size or the no. of // remaining DDP ranks. diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 92dede9a573e..1363a842eab3 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -93,7 +93,7 @@ class AsyncInputIsOutputTest : public AsyncTest { } } - void wait(std::shared_ptr& work) { + void wait(c10::intrusive_ptr& work) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(); } @@ -130,7 +130,7 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest { AsyncAllreduceTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -156,7 +156,7 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -185,7 +185,7 @@ void runAsyncAllreduceTest( size_t numProcesses = 4, size_t numTensors = 2) { auto tests = initialize(path, numProcesses, numTensors); - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(); } @@ -229,7 +229,7 @@ void runAsyncBroadcastTest( // Try every permutation of root rank and root tensor for (size_t rootRank = 0; rootRank < numProcesses; rootRank++) { for (size_t rootTensor = 0; rootTensor < numTensors; rootTensor++) { - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(rootRank, rootTensor); } diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index da4f9b5fc106..de993a1110b4 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -44,7 +44,7 @@ class SignalTest { }); } - std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { + c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { auto store = std::make_shared<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; @@ -62,7 +62,7 @@ class SignalTest { }; // Loop until an exception happens - std::shared_ptr<::c10d::ProcessGroup::Work> work; + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work; while (true) { work = pg.allreduce(tensors); try { @@ -82,7 +82,7 @@ class SignalTest { Semaphore sem_; }; -std::shared_ptr<::c10d::ProcessGroup::Work> testSignal( +c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( const std::string& path, int signal) { Fork fork; @@ -107,7 +107,7 @@ class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { Options options) : ProcessGroupGloo(store, rank, size, options) {} - std::shared_ptr<::c10d::ProcessGroup::Work> send( + c10::intrusive_ptr<::c10d::ProcessGroup::Work> send( std::vector& tensors, int dstRank, int tag) override { @@ -200,7 +200,7 @@ void testAllreduce(const std::string& path, const at::DeviceType b) { } // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().allreduce(inputs[i]); } @@ -250,7 +250,7 @@ void testBroadcast(const std::string& path, const at::DeviceType b) { options.rootTensor = j; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().broadcast(inputs[i], options); } @@ -316,7 +316,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { }; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto rank = 0; rank < size; rank++) { work[rank] = tests[rank].getProcessGroup().alltoall_base( outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]); @@ -349,7 +349,7 @@ void testBarrier(const std::string& path) { auto tests = CollectiveTest::initialize(path, size); // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().barrier(); } diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp index 3f5a9e4cf331..6c60b3d6742d 100644 --- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp +++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp @@ -14,7 +14,7 @@ // Wait for work to complete void waitWork( std::shared_ptr pg, - std::vector> works) { + std::vector> works) { for (auto& work : works) { try { work->wait(); @@ -34,10 +34,11 @@ void testAllreduce(int iter = 1000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allreduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->allreduce(tensors); works.push_back(std::move(work)); } @@ -73,10 +74,11 @@ void testBroadcast(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->broadcast(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->broadcast(tensors); works.push_back(std::move(work)); } @@ -104,10 +106,10 @@ void testReduce(int iter = 10000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); works.push_back(std::move(work)); } @@ -150,10 +152,10 @@ void testAllgather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->allgather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -198,10 +200,10 @@ void testGather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->gather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -249,10 +251,10 @@ void testScatter(int iter = 1) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->scatter(allTensors[i], allInputTensors[i]); works.push_back(std::move(work)); } @@ -289,27 +291,27 @@ void testSendRecv(bool recvAnysource, int iter = 10000) { } if (rank == 0) { - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->send(tensors, 1, 0); works.push_back(std::move(work)); } waitWork(pg, works); } if (rank == 1) { - std::vector> works; + std::vector> works; std::vector srcRanks(allTensors.size(), -1); size_t i = 0; for (auto& tensors : allTensors) { // Kick off work if (!recvAnysource) { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recv(tensors, 0, 0); works.push_back(std::move(work)); } else { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recvAnysource(tensors, 0); works.push_back(std::move(work)); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index e906702a889d..f1348922e126 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -56,12 +56,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); } - std::shared_ptr initWork( + c10::intrusive_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return std::make_shared( + return c10::make_intrusive( devices, simulate_error_, rank, opType); } @@ -113,12 +113,12 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} - std::shared_ptr initWork( + c10::intrusive_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return std::make_shared( + return c10::make_intrusive( devices, set_timedout_error_, rank, opType); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index 92b477fae7de..efa96312aba0 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -80,7 +80,7 @@ class NCCLTest : public NCCLTestBase { } void wait( - std::shared_ptr& work, + c10::intrusive_ptr& work, std::chrono::milliseconds timeout = kNoTimeout) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(timeout); @@ -166,7 +166,7 @@ class AllreduceNCCLTest : public NCCLTest { AllreduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -189,7 +189,7 @@ class BroadcastNCCLTest : public NCCLTest { BroadcastNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -208,7 +208,7 @@ class ReduceNCCLTest : public NCCLTest { ReduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -227,7 +227,7 @@ class AllgatherNCCLTest : public NCCLTest { AllgatherNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -242,7 +242,7 @@ struct ReduceScatterNCCLTest : NCCLTest { ReduceScatterNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); From 0cfe3451d485f9e1ac2828cc2dd53d331f61ce24 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 10 Nov 2020 23:27:21 -0800 Subject: [PATCH 02/93] [c10d] switch Store to be managed by intrusive_ptr (#47074) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47074 Test Plan: Imported from OSS Reviewed By: mrshenli Differential Revision: D24667128 Pulled By: wanchaol fbshipit-source-id: 9b6024c31c851b7c3243540f460ae57323da523b --- test/cpp/rpc/e2e_test_base.h | 4 ++-- test/cpp_extensions/cpp_c10d_extension.cpp | 2 +- test/cpp_extensions/cpp_c10d_extension.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 22 +++++++++---------- torch/csrc/distributed/rpc/init.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 2 +- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 +- torch/lib/c10d/PrefixStore.cpp | 2 +- torch/lib/c10d/PrefixStore.hpp | 6 +++-- torch/lib/c10d/ProcessGroupGloo.cpp | 6 ++--- torch/lib/c10d/ProcessGroupGloo.hpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.cpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.hpp | 8 +++---- torch/lib/c10d/Store.hpp | 4 +++- torch/lib/c10d/frontend.hpp | 4 ++-- torch/lib/c10d/test/FileStoreTest.cpp | 9 ++++---- torch/lib/c10d/test/HashStoreTest.cpp | 6 ++--- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 2 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 6 ++--- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +++---- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 2 +- torch/lib/c10d/test/TCPStoreTest.cpp | 14 ++++++------ 22 files changed, 61 insertions(+), 56 deletions(-) diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 9d3ab71c0cfc..114284839858 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -28,7 +28,7 @@ class TestE2EBase : public ::testing::Test { autogradContainer = getDistAutogradContainer(); // Setup server store. - store = std::make_shared( + store = c10::make_intrusive( serverAddress, 0, numWorkers, true, std::chrono::seconds(10)); buildRpcAgent(); @@ -147,7 +147,7 @@ class TestE2EBase : public ::testing::Test { std::shared_ptr rpcAgent; static const size_t numIters; static const size_t numWorkers; - std::shared_ptr store; + c10::intrusive_ptr store; static const char* serverAddress; }; diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index 50e5f5861caa..d5ba55a6379c 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -108,7 +108,7 @@ c10::intrusive_ptr ProcessGroupTest::recvAnysource( } std::shared_ptr ProcessGroupTest::createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index 8aeec736d440..1773953629d5 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -101,7 +101,7 @@ class ProcessGroupTest : public ProcessGroup { // Create a new ProcessGroupTest instance static std::shared_ptr createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 136efd32fc87..e9d8f618eb21 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -342,7 +342,7 @@ They are used in specifying strategies for reduction collectives, e.g., .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); auto store = - py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>( + py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, "Store", R"( @@ -546,7 +546,7 @@ Example:: >>> store.wait(["bad_key"], timedelta(seconds=10)) )"); - shared_ptr_class_<::c10d::FileStore>( + intrusive_ptr_class_<::c10d::FileStore>( module, "FileStore", store, @@ -569,7 +569,7 @@ Example:: .def(py::init()); #ifndef _WIN32 - shared_ptr_class_<::c10d::HashStore>( + intrusive_ptr_class_<::c10d::HashStore>( module, "HashStore", store, @@ -586,7 +586,7 @@ Example:: )") .def(py::init<>()); - shared_ptr_class_<::c10d::TCPStore>( + intrusive_ptr_class_<::c10d::TCPStore>( module, "TCPStore", store, @@ -626,7 +626,7 @@ Example:: std::chrono::milliseconds(::c10d::Store::kDefaultTimeout)); #endif - shared_ptr_class_<::c10d::PrefixStore>( + intrusive_ptr_class_<::c10d::PrefixStore>( module, "PrefixStore", store, @@ -639,7 +639,7 @@ that adds a prefix to each key inserted to the store. prefix (str): The prefix string that is prepended to each key before being inserted into the store. store (torch.distributed.store): A store object that forms the underlying key-value store. )") - .def(py::init>()); + .def(py::init>()); auto processGroup = shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") @@ -952,13 +952,13 @@ that adds a prefix to each key inserted to the store. processGroupGloo .def( py::init< - const std::shared_ptr<::c10d::Store>&, + const c10::intrusive_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupGloo::Options>(), py::call_guard()) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, std::chrono::milliseconds timeout) { @@ -997,13 +997,13 @@ that adds a prefix to each key inserted to the store. module, "ProcessGroupNCCL", processGroup) .def( py::init< - const std::shared_ptr<::c10d::Store>&, + const c10::intrusive_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupNCCL::Options>(), py::call_guard()) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { @@ -1168,7 +1168,7 @@ that adds a prefix to each key inserted to the store. // Python side of the world. Calling Python functions on a Python object // completely bypasses pybind11. We need to test that the overloaded // functions call into Python and behave like we expect. - [](std::shared_ptr<::c10d::Store> store) { + [](c10::intrusive_ptr<::c10d::Store> store) { auto add = [&store](const std::string& key, int64_t value) { store->add(key, value); }; diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 1d82a619ed7e..81af4abebd5f 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -576,7 +576,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { shared_ptr_class_(module, "TensorPipeAgent", rpcAgent) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 6bf65f4c2628..eff1e7ebdf21 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -220,7 +220,7 @@ void TensorPipeAgent::collectNames() { } TensorPipeAgent::TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index b4a500de65be..b8c9a8c64e5c 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -141,7 +141,7 @@ struct AggregatedNetworkData { class TensorPipeAgent : public RpcAgent { public: TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/lib/c10d/PrefixStore.cpp b/torch/lib/c10d/PrefixStore.cpp index 5f9a3c9c21ec..6f71e422bd0e 100644 --- a/torch/lib/c10d/PrefixStore.cpp +++ b/torch/lib/c10d/PrefixStore.cpp @@ -4,7 +4,7 @@ namespace c10d { PrefixStore::PrefixStore( const std::string& prefix, - std::shared_ptr store) + c10::intrusive_ptr store) : prefix_(prefix), store_(store) {} std::string PrefixStore::joinKey(const std::string& key) { diff --git a/torch/lib/c10d/PrefixStore.hpp b/torch/lib/c10d/PrefixStore.hpp index cad7112fbd76..ec50b3b719bf 100644 --- a/torch/lib/c10d/PrefixStore.hpp +++ b/torch/lib/c10d/PrefixStore.hpp @@ -7,7 +7,9 @@ namespace c10d { class PrefixStore : public Store { public: - explicit PrefixStore(const std::string& prefix, std::shared_ptr store); + explicit PrefixStore( + const std::string& prefix, + c10::intrusive_ptr store); virtual ~PrefixStore(){}; @@ -31,7 +33,7 @@ class PrefixStore : public Store { protected: std::string prefix_; - std::shared_ptr store_; + c10::intrusive_ptr store_; std::string joinKey(const std::string& key); std::vector joinKeys(const std::vector& keys); diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 90c9b695de28..22da878cce43 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -108,7 +108,7 @@ namespace { // Wrap c10d store as Gloo store class GlooStore : public ::gloo::rendezvous::Store { public: - GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {} + GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {} void set(const std::string& key, const std::vector& value) override { std::vector tmp(value.begin(), value.end()); @@ -131,7 +131,7 @@ class GlooStore : public ::gloo::rendezvous::Store { } protected: - std::shared_ptr<::c10d::Store> store_; + c10::intrusive_ptr<::c10d::Store> store_; }; typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); @@ -562,7 +562,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: #endif ProcessGroupGloo::ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 74fd0f6e5165..0508b6f857a1 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -152,7 +152,7 @@ class ProcessGroupGloo : public ProcessGroup { static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); explicit ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options = Options()); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index bd1563226343..acb81d0cad6d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -437,7 +437,7 @@ bool ProcessGroupNCCL::WorkNCCL::timedOut() { } ProcessGroupNCCL::ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 59f06fda1ec1..b93bd0c2d70c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -163,7 +163,7 @@ class ProcessGroupNCCL : public ProcessGroup { // Reference to the store so that we can write aborted communicators // to the store. - std::shared_ptr store_; + c10::intrusive_ptr store_; // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; @@ -393,7 +393,7 @@ class ProcessGroupNCCL : public ProcessGroup { // communicator. These NCCL communicators are cached and reused if possible. // ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options = Options()); @@ -402,7 +402,7 @@ class ProcessGroupNCCL : public ProcessGroup { // If you have existing code that uses the `groupName`, you can replace // it by specifying a `c10d::PrefixStore(groupName, store)` for store. C10_DEPRECATED ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, const std::string& groupName, @@ -594,7 +594,7 @@ class ProcessGroupNCCL : public ProcessGroup { static const int64_t kWorkCleanupThreadSleepMillis; // The store is used to broadcast the NCCL unique ID of rank 0. - std::shared_ptr store_; + c10::intrusive_ptr store_; // The number of NCCL communicators that have been created during // the lifetime of this process group. This sequence number is diff --git a/torch/lib/c10d/Store.hpp b/torch/lib/c10d/Store.hpp index e42bbf300e0b..f97e80013cdb 100644 --- a/torch/lib/c10d/Store.hpp +++ b/torch/lib/c10d/Store.hpp @@ -6,9 +6,11 @@ #include #include +#include + namespace c10d { -class Store { +class Store : public torch::CustomClassHolder { public: static constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(300); diff --git a/torch/lib/c10d/frontend.hpp b/torch/lib/c10d/frontend.hpp index 69705427b53c..3449ee30b5ef 100644 --- a/torch/lib/c10d/frontend.hpp +++ b/torch/lib/c10d/frontend.hpp @@ -35,7 +35,7 @@ class DistributedC10d { const std::chrono::milliseconds& timeout, int64_t world_size, int64_t rank, - std::shared_ptr store, + c10::intrusive_ptr store, const std::string& group_name); void destroyProcessGroup(std::shared_ptr group); @@ -202,7 +202,7 @@ class DistributedC10d { // need to use ProcessGroup or ProcesGroup* as key. std::unordered_map< std::shared_ptr, - std::pair>> + std::pair>> pg_map_; // Note, this is different mapping relationship than original Python diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index cc8da6326091..ce75c78adce7 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -41,7 +41,7 @@ std::string tmppath() { void testGetSet(std::string path, std::string prefix = "") { // Basic Set/Get on File Store { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -53,7 +53,7 @@ void testGetSet(std::string path, std::string prefix = "") { // Perform get on new instance { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::check(store, "key0", "value0"); } @@ -69,7 +69,8 @@ void stressTestStore(std::string path, std::string prefix = "") { for (auto i = 0; i < numThreads; i++) { threads.push_back(std::thread([&] { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = + c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); @@ -87,7 +88,7 @@ void stressTestStore(std::string path, std::string prefix = "") { // Check that the counter has the expected value { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(store, "counter", expected); diff --git a/torch/lib/c10d/test/HashStoreTest.cpp b/torch/lib/c10d/test/HashStoreTest.cpp index a16f83231a58..24b7fc76a417 100644 --- a/torch/lib/c10d/test/HashStoreTest.cpp +++ b/torch/lib/c10d/test/HashStoreTest.cpp @@ -11,7 +11,7 @@ void testGetSet(std::string prefix = "") { // Basic set/get { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -32,7 +32,7 @@ void testGetSet(std::string prefix = "") { // get() waits up to timeout_. { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); std::thread th([&]() { c10d::test::set(store, "key0", "value0"); }); c10d::test::check(store, "key0", "value0"); @@ -47,7 +47,7 @@ void stressTestStore(std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); for (auto i = 0; i < numThreads; i++) { diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 1363a842eab3..091ea9b2ad07 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -45,7 +45,7 @@ class AsyncTest { } void start(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index de993a1110b4..469cf32a8442 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -45,7 +45,7 @@ class SignalTest { } c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; // Set a timeout that is small enough to make this test run fast, but also @@ -101,7 +101,7 @@ c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { public: ProcessGroupGlooDelayed( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, Options options) @@ -151,7 +151,7 @@ class CollectiveTest { } void start(int rank, int size, bool delayed) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Set a timeout that is small enough to make this test run fast, but also // make sure that we don't get timeouts in the ProcessGroupGloo constructor. diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index f1348922e126..e19981c523de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -37,7 +37,7 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { public: ProcessGroupNCCLSimulateErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -106,7 +106,7 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { public: ProcessGroupNCCLTimedOutErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -153,7 +153,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { void SetUp() override { size_t numDevices = cudaNumDevices(); TemporaryFile file; - store_ = std::make_shared<::c10d::FileStore>(file.path, 1); + store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); at::cuda::OptionalCUDAGuard deviceGuard; tensors_.resize(numDevices); @@ -168,7 +168,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { } std::vector tensors_; - std::shared_ptr<::c10d::FileStore> store_; + c10::intrusive_ptr<::c10d::FileStore> store_; }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index efa96312aba0..fa5e988273fc 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -31,7 +31,7 @@ class NCCLTestBase { } void initialize(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size)); diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 0cfa72c7801a..8073ec0345e0 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -16,7 +16,7 @@ void testHelper(const std::string& prefix = "") { const auto numThreads = 16; const auto numWorkers = numThreads + 1; - auto serverTCPStore = std::make_shared( + auto serverTCPStore = c10::make_intrusive( "127.0.0.1", 0, numWorkers, @@ -25,7 +25,7 @@ void testHelper(const std::string& prefix = "") { /* wait */ false); auto serverStore = - std::make_unique(prefix, serverTCPStore); + c10::make_intrusive(prefix, serverTCPStore); // server store auto serverThread = std::thread([&serverStore, &serverTCPStore] { // Wait for all workers to join. @@ -64,13 +64,13 @@ void testHelper(const std::string& prefix = "") { c10d::test::Semaphore sem1, sem2; // Each thread will have a client store to send/recv data - std::vector> clientTCPStores; - std::vector> clientStores; + std::vector> clientTCPStores; + std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - clientTCPStores.push_back(std::make_unique( + clientTCPStores.push_back(c10::make_intrusive( "127.0.0.1", serverTCPStore->getPort(), numWorkers, false)); - clientStores.push_back(std::unique_ptr( - new c10d::PrefixStore(prefix, clientTCPStores[i]))); + clientStores.push_back( + c10::make_intrusive(prefix, clientTCPStores[i])); } std::string expectedCounterRes = std::to_string(numThreads * numIterations + 1); From ae5c2febb912066f1a8dec8b54451c09195b2c6d Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 10 Nov 2020 23:27:21 -0800 Subject: [PATCH 03/93] [c10d] switch ProcessGroupNCCL:Options to be managed by intrusive_ptr (#47075) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47075 Test Plan: Imported from OSS Reviewed By: pritamdamania87 Differential Revision: D24667127 Pulled By: wanchaol fbshipit-source-id: 54986193ba1b22480622a2e9d6d41d9472d201f3 --- torch/csrc/distributed/c10d/init.cpp | 12 +++++++----- torch/lib/c10d/ProcessGroupNCCL.cpp | 6 +++--- torch/lib/c10d/ProcessGroupNCCL.hpp | 16 +++++++++++++--- .../lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp | 16 ++++++++-------- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index e9d8f618eb21..dd32ff91603c 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1000,16 +1000,17 @@ that adds a prefix to each key inserted to the store. const c10::intrusive_ptr<::c10d::Store>&, int, int, - ::c10d::ProcessGroupNCCL::Options>(), + const c10::intrusive_ptr< + ::c10d::ProcessGroupNCCL::Options>&>(), py::call_guard()) .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { - ::c10d::ProcessGroupNCCL::Options options; - options.isHighPriorityStream = false; - options.opTimeout = timeout; + auto options = ::c10d::ProcessGroupNCCL::Options::create(); + options->isHighPriorityStream = false; + options->opTimeout = timeout; return std::make_shared<::c10d::ProcessGroupNCCL>( store, rank, size, options); }), @@ -1020,7 +1021,8 @@ that adds a prefix to each key inserted to the store. ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), py::call_guard()); - py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") + intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( + processGroupNCCL, "Options") .def(py::init<>()) .def_readwrite( "is_high_priority", diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index acb81d0cad6d..0b1a4c9f34e6 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -440,14 +440,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( const c10::intrusive_ptr& store, int rank, int size, - Options options) + const c10::intrusive_ptr& options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), terminateProcessGroup_(false), - opTimeout_(options.opTimeout), + opTimeout_(options->opTimeout), futureNCCLCallbackStreams_(c10::cuda::device_count()), - isHighPriorityStream_(options.isHighPriorityStream) { + isHighPriorityStream_(options->isHighPriorityStream) { TORCH_CHECK(at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b93bd0c2d70c..b84cc4deb051 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -17,6 +18,8 @@ #include #include +#include + namespace c10d { // Environment variable which controls whether or not wait() is blocking or @@ -175,9 +178,16 @@ class ProcessGroupNCCL : public ProcessGroup { friend class ProcessGroupNCCL; }; - struct Options { + struct Options : torch::CustomClassHolder { explicit Options(); + // return intrusive_ptr of the object + static c10::intrusive_ptr create( + std::chrono::milliseconds timeout = kNoTimeout, + bool isHighStream = false) { + return c10::make_intrusive(); + } + std::chrono::milliseconds opTimeout; bool isHighPriorityStream; }; @@ -396,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup { const c10::intrusive_ptr& store, int rank, int size, - Options options = Options()); + const c10::intrusive_ptr& options = Options::create()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace @@ -406,7 +416,7 @@ class ProcessGroupNCCL : public ProcessGroup { int rank, int size, const std::string& groupName, - Options options = Options()) + const c10::intrusive_ptr& options = Options::create()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index e19981c523de..82ca25049c63 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + const c10::intrusive_ptr& opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + const c10::intrusive_ptr& opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} @@ -177,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(1000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); @@ -206,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg( store_, 0, 1, options); @@ -229,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { return; } - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); From 5647f0ca7c306002bf8edb73c3461af1778f19e0 Mon Sep 17 00:00:00 2001 From: Erjia Guan Date: Wed, 11 Nov 2020 07:40:55 -0800 Subject: [PATCH 04/93] Revert D24859919: [pytorch][PR] Grammatically updated the tech docs Test Plan: revert-hammer Differential Revision: D24859919 (https://github.com/pytorch/pytorch/commit/a843d48ead119eb96a8dc7ed93202cf6f3558009) Original commit changeset: 5c6a8bc8e785 fbshipit-source-id: f757995fb64cfd4212c978618d572367e7296758 --- torch/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index b0e976a2bb6a..4d84b2ce31b5 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1,9 +1,9 @@ r""" -The Torch package contains data structures for multi-dimensional +The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serializing of -Tensors and arbitrary types, with other useful utilities. +Tensors and arbitrary types, and other useful utilities. It has a CUDA counterpart, that enables you to run your tensor computations on an NVIDIA GPU with compute capability >= 3.0. From 88ec72e1c2a4b1e2a15cbe4703b9567bf9369a09 Mon Sep 17 00:00:00 2001 From: Xingying Cheng Date: Wed, 11 Nov 2020 08:09:00 -0800 Subject: [PATCH 05/93] [fbcode][pytorch mobile] Create model reader utilities. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: For some of the end to end flow projects, we will need the capabilities to read module information during model validation or model publishing. Creating this model_reader.py for utilities for model content reading, this diff we included the following functionalities: 1. read the model bytecode version; 2. check if a model is lite PyTorch script module; 3. check if a model is PyTorch script module. This diff is recreated from the reverted diff: D24655999 (https://github.com/pytorch/pytorch/commit/7f056e99dd10cff3dab0a6e1931335a3bb2a1ce4). Test Plan: ``` [xcheng16@devvm1099]/data/users/xcheng16/fbsource/fbcode% buck test //caffe2/torch/fb/mobile/tests:mobile_model_reader_tests Action graph will be rebuilt because files have been added or removed. Parsing buck files: finished in 10.4 sec Creating action graph: finished in 22.2 sec Building: finished in 01:29.1 min (100%) 10619/10619 jobs, 1145 updated Total time: 02:01.8 min More details at https://www.internalfb.com/intern/buck/build/f962dfad-76f9-457a-aca3-768ce20f0c31 Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details. Running with tpx session id: 172633f6-6b5b-49e9-a632-b4efa083a001 Trace available for this run at /tmp/tpx-20201109-165156.109798/trace.log Started reporting to test run: https://our.intern.facebook.com/intern/testinfra/testrun/3940649712677511 ✓ ListingSuccess: caffe2/torch/fb/mobile/tests:mobile_model_reader_tests - main (18.229) ✓ Pass: caffe2/torch/fb/mobile/tests:mobile_model_reader_tests - test_is_pytorch_lite_module (caffe2.torch.fb.mobile.tests.test_model_reader.TestModelLoader) (8.975) ✓ Pass: caffe2/torch/fb/mobile/tests:mobile_model_reader_tests - test_is_pytorch_script_module (caffe2.torch.fb.mobile.tests.test_model_reader.TestModelLoader) (9.136) ✓ Pass: caffe2/torch/fb/mobile/tests:mobile_model_reader_tests - test_read_module_bytecode_version (caffe2.torch.fb.mobile.tests.test_model_reader.TestModelLoader) (9.152) Summary Pass: 3 ListingSuccess: 1 Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/3940649712677511 ``` Reviewed By: husthyc Differential Revision: D24848563 fbshipit-source-id: ab3371e111206a4bb4d07715c3314596cdc38d2c --- torch/utils/show_pickle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py index 0e2498d64c56..9e55ebff48b9 100644 --- a/torch/utils/show_pickle.py +++ b/torch/utils/show_pickle.py @@ -68,6 +68,7 @@ def persistent_load(self, pid): def dump(cls, in_stream, out_stream): value = cls(in_stream).load() pprint.pprint(value, stream=out_stream) + return value def main(argv, output_stream=None): From 48ed577fbd0876c2ab1432dfaf8c3009067828ff Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 11 Nov 2020 08:45:02 -0800 Subject: [PATCH 06/93] Stop including TypeDefault.h from MPSCNNTests.mm (#46998) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46998 It's not using any TypeDefault symbols directly; running CI to see if it was being included for other headers. Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D24621920 Pulled By: ezyang fbshipit-source-id: f868e5412ff3e5a616c3fc38110f203ca545eed5 --- aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm | 1 - 1 file changed, 1 deletion(-) diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index e56845fd1f9e..bbcbfe10fd01 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -7,7 +7,6 @@ #import #include -#include #import #include From 4cb73f5a4c841ffa0f20a27d920173e16549c4f7 Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Wed, 11 Nov 2020 08:52:06 -0800 Subject: [PATCH 07/93] Allow for string literal return during symbolic tracing (#47618) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47618 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D24870422 Pulled By: ansley fbshipit-source-id: 41c56c2f4f1f7bb360cea0fb346f6e4d495f5c2b --- test/test_fx.py | 11 +++++++++++ torch/fx/graph.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_fx.py b/test/test_fx.py index 1796ad2e87ef..dcb104528402 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1141,6 +1141,17 @@ def forward(self, x, y=1): self.checkGraphModule(m, (2,)) self.checkGraphModule(m, (2, 3)) + def test_string_literal_return(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self): + return "foo" + + m = M() + self.checkGraphModule(m, ()) + if __name__ == '__main__': run_tests() diff --git a/torch/fx/graph.py b/torch/fx/graph.py index d737a1a65629..dd07ff7a508e 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -385,7 +385,7 @@ def type_repr(o : Any): elif node.op == 'output': if node.type is not None: maybe_return_annotation = f" -> {type_repr(node.type)}" - body.append(f'return {node.args[0]}') + body.append(f'return {repr(node.args[0])}') continue raise NotImplementedError(f'node: {node.op} {node.target}') From 1239d067aebc314d5eb488026d271d24cb9f2ca4 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Nov 2020 09:13:12 -0800 Subject: [PATCH 08/93] [quant][graphmode][fx] Support standalone_module_class (#47705) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47705 Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D24872380 fbshipit-source-id: db2ec7ba03da27203033fbebc11666be572622bb --- test/quantization/test_quantize_fx.py | 86 ++++++++++++++------------- torch/quantization/fx/quantize.py | 20 +++++-- torch/quantization/quantize_fx.py | 8 +++ 3 files changed, 67 insertions(+), 47 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index e057e25643a4..16694b0f0356 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -500,48 +500,52 @@ def forward(self, x): original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) qconfig_dict = {"": default_qconfig} - prepare_custom_config_dict = {"standalone_module_name": ["standalone"]} - # check prepared model - m = prepare_fx( - original_m, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict) - # calibration - m(data) - # input and output of first conv, observer for standalone module - # will be inserted in the standalone module itself - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - # for output of conv in the standalone module - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 1 - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + config_name = {"standalone_module_name": ["standalone"]} + config_class = {"standalone_module_class": [StandaloneModule]} + for prepare_config in [config_name, config_class]: + original_m_copy = copy.deepcopy(original_m) + original_ref_m_copy = copy.deepcopy(original_ref_m) + # check prepared model + m = prepare_fx( + original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) + # calibration + m(data) + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + # for output of conv in the standalone module + count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) - # check converted/quantized model - m = convert_fx(m) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - count_check = { - # quantization of input happens in parent module - # quantization of output happens in the quantized conv module - ns.call_function(torch.quantize_per_tensor) : 0, - # dequantization for output happens in parent module - ns.call_method('dequantize') : 0, - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) - res = m(data) - - # quantize the reference model - ref_m = prepare_fx(original_ref_m, qconfig_dict) - ref_m(data) - ref_m = convert_fx(ref_m) - ref_res = ref_m(data) - self.assertEqual(res, ref_res) + # check converted/quantized model + m = convert_fx(m) + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + ns.call_method('dequantize') : 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + # dequantization for output happens in parent module + ns.call_method('dequantize') : 0, + } + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + res = m(data) + + # quantize the reference model + ref_m = prepare_fx(original_ref_m_copy, qconfig_dict) + ref_m(data) + ref_m = convert_fx(ref_m) + ref_res = ref_m(data) + self.assertEqual(res, ref_res) @skipIfNoFBGEMM def test_qconfig_none(self): diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 35e3e1ac8efb..b87612c97dbe 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -347,9 +347,10 @@ def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalon # match the patterns that will get quantized standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None) + standalone_module_classes = prepare_custom_config_dict.get("standalone_module_class", None) custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class") matches = self._find_matches( - model.graph, self.modules, self.patterns, standalone_module_names, custom_module_classes) + model.graph, self.modules, self.patterns, standalone_module_names, standalone_module_classes, custom_module_classes) # find _inputs_ to matched nodes that are not quantized, these # have to be quantized, which requires measuring stats, @@ -826,7 +827,9 @@ def convert(self, model, debug=False, convert_custom_config_dict=None, is_standa def _find_matches( self, graph, modules, patterns, - standalone_module_names=None, custom_module_classes=None): + standalone_module_names=None, + standalone_module_classes=None, + custom_module_classes=None): """ Matches the nodes in the input graph to quantization patterns, and outputs the information needed to quantize them in future steps. @@ -850,6 +853,12 @@ def _find_matches( if custom_module_classes is None: custom_module_classes = [] + if standalone_module_classes is None: + standalone_module_classes = [] + + if standalone_module_names is None: + standalone_module_names = [] + match_map = {} all_matched = set() @@ -883,10 +892,9 @@ def record_match(pattern, node, matched): match_map[node.name] = ( node, [node], None, CustomModuleQuantizeHandler(self, node), custom_module_qconfig) - def is_standalone_module(module_path): - if standalone_module_names is None: - return False - return module_path in standalone_module_names + def is_standalone_module(node_target): + return node_target in standalone_module_names or \ + type(self.modules[node_target]) in standalone_module_classes # add standalone modules to the match for node in graph.nodes: diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 93043559bf48..91d58c2966a4 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -75,6 +75,9 @@ def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standal # standalone module and custom module config are applied in top level module standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', []) skipped_module_names += standalone_module_names + + standalone_module_classes = prepare_custom_config_dict.get('standalone_module_class', []) + skipped_module_classes += standalone_module_classes float_custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") skipped_module_classes += float_custom_module_classes @@ -170,6 +173,11 @@ def prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None): "standalone_module_name": [ "submodule.standalone" ], + + "standalone_module_class": [ + StandaloneModule + ], + # user will manually define the corresponding observed # module class which has a from_float class method that converts # float custom module to observed custom module From d478605dec65a746d41506b23693d6013bfa11b2 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi Date: Wed, 11 Nov 2020 09:23:02 -0800 Subject: [PATCH 09/93] Fix classmethod override argument passing. (#47114) Summary: Fixes https://github.com/pytorch/pytorch/issues/47069. Fixes https://github.com/pytorch/pytorch/issues/46824. Fixes https://github.com/pytorch/pytorch/issues/47186 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47114 Reviewed By: ngimel Differential Revision: D24649598 Pulled By: ezyang fbshipit-source-id: af077affece7eceb1e4faf9c94d15484796b0f0e --- test/test_overrides.py | 7 ++++ .../templates/python_variable_methods.cpp | 36 +++++++++---------- torch/csrc/utils/python_arg_parser.cpp | 10 +++--- torch/csrc/utils/python_arg_parser.h | 4 +-- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/test/test_overrides.py b/test/test_overrides.py index a0565ec30c8d..f12d9ace9cbd 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -831,5 +831,12 @@ def test_max(self): self.assertEqual(type(r), type(rs)) self.assertEqual(r, rs) +class TestGradNewOnesOverride(TestCase): + """ Regression test for gh-47069 """ + def test_newones(self): + t = torch.tensor([1, 2]).as_subclass(SubTensor2) + n = t.new_ones((1, 2)) + self.assertEqual(type(n), SubTensor2) + if __name__ == '__main__': unittest.main() diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 4b441d6f3616..15fd600e441c 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -56,7 +56,7 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "_is_view"); + return handle_torch_function(self, "_is_view", args); } auto& self_ = reinterpret_cast(self)->cdata; if (self_.is_view()) { @@ -160,7 +160,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "get_device"); + return handle_torch_function(self_, "get_device", args, nullptr); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.get_device()); @@ -171,7 +171,7 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "has_names"); + return handle_torch_function(self_, "has_names", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.has_names()); @@ -183,7 +183,7 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "data_ptr"); + return handle_torch_function(self_, "data_ptr", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.data_ptr()); @@ -207,7 +207,7 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "dim"); + return handle_torch_function(self, "dim", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.dim()); @@ -219,7 +219,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "numel"); + return handle_torch_function(self, "numel", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.numel()); @@ -333,7 +333,7 @@ static bool dispatch_to_Bool(const Tensor & self) { static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__float__"); + return handle_torch_function(self, "__float__", args); } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -344,7 +344,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__complex__"); + return handle_torch_function(self, "__complex__", args); } jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -355,7 +355,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__int__"); + return handle_torch_function(self, "__int__", args); } jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -374,7 +374,7 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__index__"); + return handle_torch_function(self, "__index__", args); } jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -396,7 +396,7 @@ static Tensor dispatch_invert(const Tensor & self) { static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__invert__"); + return handle_torch_function(self, "__invert__", args); } auto& self_ = reinterpret_cast(self)->cdata; if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { @@ -691,7 +691,7 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "element_size"); + return handle_torch_function(self, "element_size", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.element_size()); @@ -769,7 +769,7 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "item"); + return handle_torch_function(self, "item", args); } jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -838,7 +838,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new"); + return handle_torch_function(self, "new", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -850,7 +850,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_ones"); + return handle_torch_function(self, "new_ones", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -862,7 +862,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_tensor"); + return handle_torch_function(self, "new_tensor", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -941,7 +941,7 @@ static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "tolist"); + return handle_torch_function(self, "tolist", args); } jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); auto self_ = reinterpret_cast(self)->cdata; @@ -1010,7 +1010,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { if (check_has_torch_function(self)) { HANDLE_TH_ERRORS - return handle_torch_function(self, "__bool__"); + return handle_torch_function(self, "__bool__", args); END_HANDLE_TH_ERRORS } jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index ff94b1f5ceca..950e7d9fb82d 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -139,7 +139,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); std::string module_name = "torch.Tensor." + property_name; - return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name); + return handle_torch_function((PyObject *)self, "__get__", nullptr, nullptr, torch_api.ptr(), module_name); } auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { @@ -148,10 +148,10 @@ auto handle_torch_function_setter(THPVariable* self, const std::string& property if (value != nullptr) { py::tuple args_ = py::make_tuple(py::handle(value)); - handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__set__", args_.ptr(), nullptr, torch_api.ptr(), module_name); } else { - handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__delete__", nullptr, nullptr, torch_api.ptr(), module_name); } return 0; } @@ -175,13 +175,13 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { return args_; } -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* { +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); py::tuple args_ = combine_self_args(self, args); py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self))); py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__"); - py::object ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL)); + py::object ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs)); if (ret.ptr() == nullptr) { // if an exception occurred in a user's implementation of // __torch_function__, throw it diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 773486f30ee1..b0b81a9517da 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -820,8 +820,8 @@ auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObje // Used for functions which needs to parse python args. auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*; -// Used for functions that accept no keyword arguments and have no argument parsing -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; +// Used for functions that have no argument parsing. +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; // Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args. auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*; From 0c64f9f52614bf8fd7a75f1ea343f0225548fde0 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 11 Nov 2020 10:22:43 -0800 Subject: [PATCH 10/93] Convert from higher order functions to classes in tools.codegen.gen (#47008) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47008 bhosmer has been complaining about how it is difficult to distinguish between local variables and closed over variables in the higher order functions. Well, closures and objects do basically the same thing, so just convert all these HOFs into objects. The decoder ring: - Higher order function => Constructor for object - Access to closed over variable => Access to member variable on object - with_native_function => method_with_native_function (because it's hard writing decorators that work for both functions and methods) I didn't even have to change indentation (much). When there is no need for closed over variables (a few functions), I kept them as plain old functions, no need for an object with no members. While I was at it, I also deleted the kwargs, since the types are enough to prevent mistakes. Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D24600805 Pulled By: ezyang fbshipit-source-id: 7e3ce8cb2446e3788f934ddcc17f7da6e9299511 --- tools/codegen/gen.py | 154 ++++++++++++++++++++++++------------------- 1 file changed, 86 insertions(+), 68 deletions(-) diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 2b8f4fc64959..0ed2dff543fe 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -10,6 +10,7 @@ import pathlib import functools import json +from dataclasses import dataclass from tools.codegen.code_template import CodeTemplate from tools.codegen.model import * @@ -102,13 +103,25 @@ def parse_native_yaml(path: str) -> List[NativeFunction]: def with_native_function(func: Callable[[NativeFunction], T]) -> Callable[[NativeFunction], T]: @functools.wraps(func) def wrapper(f: NativeFunction) -> T: - with context(f'in {f.loc}:\n {f.func}'): - with local.parametrize( - use_c10_dispatcher=f.use_c10_dispatcher, - ): - return func(f) + with native_function_manager(f): + return func(f) return wrapper +def method_with_native_function(func: Callable[[S, NativeFunction], T]) -> Callable[[S, NativeFunction], T]: + @functools.wraps(func) + def wrapper(slf: S, f: NativeFunction) -> T: + with native_function_manager(f): + return func(slf, f) + return wrapper + +@contextlib.contextmanager +def native_function_manager(f: NativeFunction) -> Iterator[None]: + with context(f'in {f.loc}:\n {f.func}'): + with local.parametrize( + use_c10_dispatcher=f.use_c10_dispatcher, + ): + yield + # These two functions purposely return generators in analogy to map() # so that you don't mix up when you need to list() them @@ -180,49 +193,53 @@ def cpp_string(s: str) -> str: # # This function is also used for a secondary purpose: the registration # logic is also reused to implement per-operator registration. -def compute_type_method( - dispatch: Optional[str], *, +@dataclass(frozen=True) +class ComputeTypeMethod: + dispatch: Optional[str] + # TODO: Give more precise type Union[Literal[Target.DEFINITION, # Target.REGISTRATION]]; requires Literal from typing_extensions # which we don't have a dep for yet. - target: Target, + target: Target + # Selector object to determine which operators to generate # registration code for. selector: SelectiveBuilder -) -> Callable[[NativeFunction], Optional[str]]: - if dispatch is None: - assert target is Target.REGISTRATION + def __post_init__(self) -> None: + assert self.target is not Target.DECLARATION + if self.dispatch is None: + assert self.target is Target.REGISTRATION - @with_native_function - def func(f: NativeFunction) -> Optional[str]: - # Has to be here as mypy won't transfer asserts into closures - assert target is not Target.DECLARATION + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + # for mypy type refinement; would be fixed by TODO on target + assert self.target is not Target.DECLARATION - if dispatch is not None: - if dispatch not in f.dispatch: + if self.dispatch is not None: + if self.dispatch not in f.dispatch: return None op_name = f"aten::{f.func.name}" - if target is Target.REGISTRATION and not selector.is_operator_selected(op_name): + if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name): return None name = native.name(f.func) returns_type = native.returns_type(f.func.returns) args = native.arguments(f.func) args_str = ', '.join(map(str, args)) - dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS + dispatch_to_all_backends = self.dispatch is not None and self.dispatch in KEYWORD_ALL_BACKENDS - if target is Target.DEFINITION: - assert dispatch is not None - impl_name = f"at::native::{f.dispatch[dispatch]}" + if self.target is Target.DEFINITION: + assert self.dispatch is not None + impl_name = f"at::native::{f.dispatch[self.dispatch]}" args_exprs_str = ', '.join(a.name for a in args) return_kw = " return " cuda_guard = "" - if dispatch_to_all_backends or 'CUDA' in dispatch: + if dispatch_to_all_backends or 'CUDA' in self.dispatch: self_args = (a for a in f.func.arguments if a.name == "self") # There is precedence for which argument we use to do @@ -249,7 +266,7 @@ def func(f: NativeFunction) -> Optional[str]: # works just as well. if f.device_guard and dispatch_to_all_backends and has_tensor_options: cuda_guard = cuda_guard_from_tensor_options - elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options: + elif f.device_guard and self.dispatch is not None and 'CUDA' in self.dispatch and has_tensor_options: cuda_guard = f"""\ globalContext().lazyInitCUDA(); {cuda_guard_from_tensor_options} @@ -269,8 +286,8 @@ def func(f: NativeFunction) -> Optional[str]: }} """ - elif target is Target.REGISTRATION: - if dispatch is None: + elif self.target is Target.REGISTRATION: + if self.dispatch is None: return f'm.def({cpp_string(str(f.func))});\n' elif f.manual_kernel_registration: return None @@ -278,7 +295,7 @@ def func(f: NativeFunction) -> Optional[str]: if dispatch_to_all_backends: type_name = f'TypeDefault::{name}' else: - type_name = f'{dispatch}Type::{name}' + type_name = f'{self.dispatch}Type::{name}' dispatcher_sig = DispatcherSignature.from_schema(f.func) @@ -302,21 +319,22 @@ def func(f: NativeFunction) -> Optional[str]: # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So # the torch::dispatch specification here is important! See # Note [Redundancy in registration code is OK] for how we handle redundant info. - if dispatch is not None: - payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n" + if self.dispatch is not None: + payload = f"torch::dispatch(DispatchKey::{self.dispatch},\n{payload})\n" return f'm.impl("{f.func.name}",\n{payload});\n' else: - assert_never(target) - - return func + assert_never(self.target) # Generates Function.cpp and Function.h. These files provide the # functional public C++ API, and the scaffolding to call into # the dispatcher from these functions. See also compute_tensor_method. -def compute_function(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: +@dataclass(frozen=True) +class ComputeFunction: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if f.manual_kernel_registration: return None if Variant.function not in f.variants: @@ -326,13 +344,13 @@ def go(f: NativeFunction) -> Optional[str]: sig_group = CppSignatureGroup.from_schema(f.func, method=False) - if target is Target.DECLARATION: + if self.target is Target.DECLARATION: result = f"CAFFE2_API {sig_group.signature.decl()};\n" if sig_group.faithful_signature is not None: result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n" return result - assert target is Target.DEFINITION + assert self.target is Target.DEFINITION def generate_defn(sig: CppSignature) -> str: dispatcher_sig = DispatcherSignature.from_schema(f.func) @@ -357,14 +375,15 @@ def generate_defn(sig: CppSignature) -> str: return result - return go - # Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the # object-oriented (method-based) public C++ API, and the scaffolding to call into # the dispatcher from these functions. See also compute_function. -def compute_tensor_method(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: +@dataclass(frozen=True) +class ComputeTensorMethod: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if Variant.method not in f.variants: return None @@ -376,13 +395,13 @@ def go(f: NativeFunction) -> Optional[str]: sig_group = CppSignatureGroup.from_schema(f.func, method=True) - if target is Target.DECLARATION: + if self.target is Target.DECLARATION: result = f"{sig_group.signature.decl()} const;\n" if sig_group.faithful_signature is not None: result += f"{sig_group.faithful_signature.decl()} const;\n" return result - assert target is Target.DEFINITION + assert self.target is Target.DEFINITION def generate_defn(sig: CppSignature) -> str: dispatcher_sig = DispatcherSignature.from_schema(f.func) @@ -406,8 +425,6 @@ def generate_defn(sig: CppSignature) -> str: return result - return go - # Generates ATenOpList.cpp, a runtime accessible list of all aten # operators. # TODO: This was historically used to help some JIT interop code @@ -442,9 +459,12 @@ def compute_native_function_declaration(f: NativeFunction) -> List[str]: # Generates BackendSelectRegister.cpp, a series of kernels which provide # specialized computation of dispatch key for operator signatures which cannot # be easily done automatically using templating. -def compute_backend_select(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: +@dataclass(frozen=True) +class ComputeBackendSelect: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'): return None @@ -471,7 +491,7 @@ def go(f: NativeFunction) -> Optional[str]: dispatcher_exprs = native_sig.dispatcher_exprs() dispatch_key = "options.computeDispatchKey()" - if target is Target.DEFINITION: + if self.target is Target.DEFINITION: # I don't think there's actually a good reason to generate # these two cases differently # The first case could probably be improved though- it calls dispatchTypeId(), @@ -494,7 +514,7 @@ def go(f: NativeFunction) -> Optional[str]: return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)}); }} """ - elif target is Target.REGISTRATION: + elif self.target is Target.REGISTRATION: if local.use_c10_dispatcher() is UseC10Dispatcher.full: return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: @@ -504,11 +524,10 @@ def go(f: NativeFunction) -> Optional[str]: else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});""" - elif target is Target.DECLARATION: + elif self.target is Target.DECLARATION: raise AssertionError() else: - assert_never(target) - return go + assert_never(self.target) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -993,12 +1012,11 @@ def make_file_manager(install_dir: str) -> FileManager: '', 'Backend': dispatch, 'type_derived_method_definitions': list(mapMaybe( - compute_type_method(dispatch, target=Target.DEFINITION, selector=selector), + ComputeTypeMethod(dispatch, Target.DEFINITION, selector), native_functions )), 'function_registrations': list(mapMaybe( - compute_type_method( - dispatch, target=Target.REGISTRATION, selector=selector), + ComputeTypeMethod(dispatch, Target.REGISTRATION, selector), native_functions )), }) @@ -1012,35 +1030,35 @@ def make_file_manager(install_dir: str) -> FileManager: cpu_fm.write('TypeDefault.cpp', lambda: { 'type_method_definitions': list(mapMaybe( - compute_type_method('Math', target=Target.DEFINITION, selector=selector), + ComputeTypeMethod('Math', Target.DEFINITION, selector), native_functions)) + list(mapMaybe( - compute_type_method('DefaultBackend', target=Target.DEFINITION, selector=selector), + ComputeTypeMethod('DefaultBackend', Target.DEFINITION, selector), native_functions)), 'function_registrations': list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, selector=schema_selector), + ComputeTypeMethod(None, Target.REGISTRATION, schema_selector), native_functions)), 'math_function_registrations': list(mapMaybe( - compute_type_method('Math', target=Target.REGISTRATION, selector=selector), + ComputeTypeMethod('Math', Target.REGISTRATION, selector), native_functions)), 'default_backend_function_registrations': list(mapMaybe( - compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector), + ComputeTypeMethod('DefaultBackend', Target.REGISTRATION, selector), native_functions)), }) cpu_fm.write('Functions.h', lambda: { - 'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)), + 'function_declarations': list(mapMaybe(ComputeFunction(Target.DECLARATION), native_functions)), }) cpu_fm.write('Functions.cpp', lambda: { - 'function_definitions': list(mapMaybe(compute_function(target=Target.DEFINITION), native_functions)), + 'function_definitions': list(mapMaybe(ComputeFunction(Target.DEFINITION), native_functions)), }) core_fm.write('TensorBody.h', lambda: { - 'tensor_method_declarations': list(mapMaybe(compute_tensor_method(target=Target.DECLARATION), native_functions)), + 'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(Target.DECLARATION), native_functions)), }) core_fm.write('TensorMethods.cpp', lambda: { - 'tensor_method_definitions': list(mapMaybe(compute_tensor_method(target=Target.DEFINITION), native_functions)), + 'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(Target.DEFINITION), native_functions)), }) core_fm.write('ATenOpList.cpp', lambda: { 'aten_ops': list(mapMaybe(compute_aten_op, native_functions)), @@ -1050,9 +1068,9 @@ def make_file_manager(install_dir: str) -> FileManager: }) cpu_fm.write('BackendSelectRegister.cpp', lambda: { 'backend_select_method_definitions': - list(mapMaybe(compute_backend_select(target=Target.DEFINITION), native_functions)), + list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)), 'backend_select_function_registrations': - list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)), + list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)), }) cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) From 2204374fd434df1fe5c38d2410da089857182563 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 10:35:04 -0800 Subject: [PATCH 11/93] Revert D24667127: [c10d] switch ProcessGroupNCCL:Options to be managed by intrusive_ptr Test Plan: revert-hammer Differential Revision: D24667127 (https://github.com/pytorch/pytorch/commit/ae5c2febb912066f1a8dec8b54451c09195b2c6d) Original commit changeset: 54986193ba1b fbshipit-source-id: 12e1ebea1981c0b1b6dff4c8a2e2045878d44537 --- torch/csrc/distributed/c10d/init.cpp | 12 +++++------- torch/lib/c10d/ProcessGroupNCCL.cpp | 6 +++--- torch/lib/c10d/ProcessGroupNCCL.hpp | 16 +++------------- .../lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp | 16 ++++++++-------- 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index dd32ff91603c..e9d8f618eb21 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1000,17 +1000,16 @@ that adds a prefix to each key inserted to the store. const c10::intrusive_ptr<::c10d::Store>&, int, int, - const c10::intrusive_ptr< - ::c10d::ProcessGroupNCCL::Options>&>(), + ::c10d::ProcessGroupNCCL::Options>(), py::call_guard()) .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { - auto options = ::c10d::ProcessGroupNCCL::Options::create(); - options->isHighPriorityStream = false; - options->opTimeout = timeout; + ::c10d::ProcessGroupNCCL::Options options; + options.isHighPriorityStream = false; + options.opTimeout = timeout; return std::make_shared<::c10d::ProcessGroupNCCL>( store, rank, size, options); }), @@ -1021,8 +1020,7 @@ that adds a prefix to each key inserted to the store. ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), py::call_guard()); - intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( - processGroupNCCL, "Options") + py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") .def(py::init<>()) .def_readwrite( "is_high_priority", diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 0b1a4c9f34e6..acb81d0cad6d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -440,14 +440,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options) + Options options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), terminateProcessGroup_(false), - opTimeout_(options->opTimeout), + opTimeout_(options.opTimeout), futureNCCLCallbackStreams_(c10::cuda::device_count()), - isHighPriorityStream_(options->isHighPriorityStream) { + isHighPriorityStream_(options.isHighPriorityStream) { TORCH_CHECK(at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b84cc4deb051..b93bd0c2d70c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -18,8 +17,6 @@ #include #include -#include - namespace c10d { // Environment variable which controls whether or not wait() is blocking or @@ -178,16 +175,9 @@ class ProcessGroupNCCL : public ProcessGroup { friend class ProcessGroupNCCL; }; - struct Options : torch::CustomClassHolder { + struct Options { explicit Options(); - // return intrusive_ptr of the object - static c10::intrusive_ptr create( - std::chrono::milliseconds timeout = kNoTimeout, - bool isHighStream = false) { - return c10::make_intrusive(); - } - std::chrono::milliseconds opTimeout; bool isHighPriorityStream; }; @@ -406,7 +396,7 @@ class ProcessGroupNCCL : public ProcessGroup { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options = Options::create()); + Options options = Options()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace @@ -416,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup { int rank, int size, const std::string& groupName, - const c10::intrusive_ptr& options = Options::create()) + Options options = Options()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 82ca25049c63..e19981c523de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10d::ProcessGroupNCCL::Options opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10d::ProcessGroupNCCL::Options opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} @@ -177,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(1000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); @@ -206,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(3000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg( store_, 0, 1, options); @@ -229,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { return; } - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(3000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); From 1f946e942da498cdf3de621bb52f7b3d85fa2f0f Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 10:42:13 -0800 Subject: [PATCH 12/93] Revert D24667128: [c10d] switch Store to be managed by intrusive_ptr Test Plan: revert-hammer Differential Revision: D24667128 (https://github.com/pytorch/pytorch/commit/0cfe3451d485f9e1ac2828cc2dd53d331f61ce24) Original commit changeset: 9b6024c31c85 fbshipit-source-id: d8ddf9eb2fccef5023e05698e0c4662708fe4945 --- test/cpp/rpc/e2e_test_base.h | 4 ++-- test/cpp_extensions/cpp_c10d_extension.cpp | 2 +- test/cpp_extensions/cpp_c10d_extension.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 22 +++++++++---------- torch/csrc/distributed/rpc/init.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 2 +- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 +- torch/lib/c10d/PrefixStore.cpp | 2 +- torch/lib/c10d/PrefixStore.hpp | 6 ++--- torch/lib/c10d/ProcessGroupGloo.cpp | 6 ++--- torch/lib/c10d/ProcessGroupGloo.hpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.cpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.hpp | 8 +++---- torch/lib/c10d/Store.hpp | 4 +--- torch/lib/c10d/frontend.hpp | 4 ++-- torch/lib/c10d/test/FileStoreTest.cpp | 9 ++++---- torch/lib/c10d/test/HashStoreTest.cpp | 6 ++--- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 2 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 6 ++--- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +++---- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 2 +- torch/lib/c10d/test/TCPStoreTest.cpp | 14 ++++++------ 22 files changed, 56 insertions(+), 61 deletions(-) diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 114284839858..9d3ab71c0cfc 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -28,7 +28,7 @@ class TestE2EBase : public ::testing::Test { autogradContainer = getDistAutogradContainer(); // Setup server store. - store = c10::make_intrusive( + store = std::make_shared( serverAddress, 0, numWorkers, true, std::chrono::seconds(10)); buildRpcAgent(); @@ -147,7 +147,7 @@ class TestE2EBase : public ::testing::Test { std::shared_ptr rpcAgent; static const size_t numIters; static const size_t numWorkers; - c10::intrusive_ptr store; + std::shared_ptr store; static const char* serverAddress; }; diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index d5ba55a6379c..50e5f5861caa 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -108,7 +108,7 @@ c10::intrusive_ptr ProcessGroupTest::recvAnysource( } std::shared_ptr ProcessGroupTest::createProcessGroupTest( - const c10::intrusive_ptr<::c10d::Store>& store, + const std::shared_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index 1773953629d5..8aeec736d440 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -101,7 +101,7 @@ class ProcessGroupTest : public ProcessGroup { // Create a new ProcessGroupTest instance static std::shared_ptr createProcessGroupTest( - const c10::intrusive_ptr<::c10d::Store>& store, + const std::shared_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index e9d8f618eb21..136efd32fc87 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -342,7 +342,7 @@ They are used in specifying strategies for reduction collectives, e.g., .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); auto store = - py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( + py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>( module, "Store", R"( @@ -546,7 +546,7 @@ Example:: >>> store.wait(["bad_key"], timedelta(seconds=10)) )"); - intrusive_ptr_class_<::c10d::FileStore>( + shared_ptr_class_<::c10d::FileStore>( module, "FileStore", store, @@ -569,7 +569,7 @@ Example:: .def(py::init()); #ifndef _WIN32 - intrusive_ptr_class_<::c10d::HashStore>( + shared_ptr_class_<::c10d::HashStore>( module, "HashStore", store, @@ -586,7 +586,7 @@ Example:: )") .def(py::init<>()); - intrusive_ptr_class_<::c10d::TCPStore>( + shared_ptr_class_<::c10d::TCPStore>( module, "TCPStore", store, @@ -626,7 +626,7 @@ Example:: std::chrono::milliseconds(::c10d::Store::kDefaultTimeout)); #endif - intrusive_ptr_class_<::c10d::PrefixStore>( + shared_ptr_class_<::c10d::PrefixStore>( module, "PrefixStore", store, @@ -639,7 +639,7 @@ that adds a prefix to each key inserted to the store. prefix (str): The prefix string that is prepended to each key before being inserted into the store. store (torch.distributed.store): A store object that forms the underlying key-value store. )") - .def(py::init>()); + .def(py::init>()); auto processGroup = shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") @@ -952,13 +952,13 @@ that adds a prefix to each key inserted to the store. processGroupGloo .def( py::init< - const c10::intrusive_ptr<::c10d::Store>&, + const std::shared_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupGloo::Options>(), py::call_guard()) .def( - py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + py::init([](const std::shared_ptr<::c10d::Store>& store, int rank, int size, std::chrono::milliseconds timeout) { @@ -997,13 +997,13 @@ that adds a prefix to each key inserted to the store. module, "ProcessGroupNCCL", processGroup) .def( py::init< - const c10::intrusive_ptr<::c10d::Store>&, + const std::shared_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupNCCL::Options>(), py::call_guard()) .def( - py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + py::init([](const std::shared_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { @@ -1168,7 +1168,7 @@ that adds a prefix to each key inserted to the store. // Python side of the world. Calling Python functions on a Python object // completely bypasses pybind11. We need to test that the overloaded // functions call into Python and behave like we expect. - [](c10::intrusive_ptr<::c10d::Store> store) { + [](std::shared_ptr<::c10d::Store> store) { auto add = [&store](const std::string& key, int64_t value) { store->add(key, value); }; diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 81af4abebd5f..1d82a619ed7e 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -576,7 +576,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { shared_ptr_class_(module, "TensorPipeAgent", rpcAgent) .def( - py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + py::init([](const std::shared_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index eff1e7ebdf21..6bf65f4c2628 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -220,7 +220,7 @@ void TensorPipeAgent::collectNames() { } TensorPipeAgent::TensorPipeAgent( - const c10::intrusive_ptr<::c10d::Store>& store, + const std::shared_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index b8c9a8c64e5c..b4a500de65be 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -141,7 +141,7 @@ struct AggregatedNetworkData { class TensorPipeAgent : public RpcAgent { public: TensorPipeAgent( - const c10::intrusive_ptr<::c10d::Store>& store, + const std::shared_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/lib/c10d/PrefixStore.cpp b/torch/lib/c10d/PrefixStore.cpp index 6f71e422bd0e..5f9a3c9c21ec 100644 --- a/torch/lib/c10d/PrefixStore.cpp +++ b/torch/lib/c10d/PrefixStore.cpp @@ -4,7 +4,7 @@ namespace c10d { PrefixStore::PrefixStore( const std::string& prefix, - c10::intrusive_ptr store) + std::shared_ptr store) : prefix_(prefix), store_(store) {} std::string PrefixStore::joinKey(const std::string& key) { diff --git a/torch/lib/c10d/PrefixStore.hpp b/torch/lib/c10d/PrefixStore.hpp index ec50b3b719bf..cad7112fbd76 100644 --- a/torch/lib/c10d/PrefixStore.hpp +++ b/torch/lib/c10d/PrefixStore.hpp @@ -7,9 +7,7 @@ namespace c10d { class PrefixStore : public Store { public: - explicit PrefixStore( - const std::string& prefix, - c10::intrusive_ptr store); + explicit PrefixStore(const std::string& prefix, std::shared_ptr store); virtual ~PrefixStore(){}; @@ -33,7 +31,7 @@ class PrefixStore : public Store { protected: std::string prefix_; - c10::intrusive_ptr store_; + std::shared_ptr store_; std::string joinKey(const std::string& key); std::vector joinKeys(const std::vector& keys); diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 22da878cce43..90c9b695de28 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -108,7 +108,7 @@ namespace { // Wrap c10d store as Gloo store class GlooStore : public ::gloo::rendezvous::Store { public: - GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {} + GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {} void set(const std::string& key, const std::vector& value) override { std::vector tmp(value.begin(), value.end()); @@ -131,7 +131,7 @@ class GlooStore : public ::gloo::rendezvous::Store { } protected: - c10::intrusive_ptr<::c10d::Store> store_; + std::shared_ptr<::c10d::Store> store_; }; typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); @@ -562,7 +562,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: #endif ProcessGroupGloo::ProcessGroupGloo( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 0508b6f857a1..74fd0f6e5165 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -152,7 +152,7 @@ class ProcessGroupGloo : public ProcessGroup { static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); explicit ProcessGroupGloo( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, Options options = Options()); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index acb81d0cad6d..bd1563226343 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -437,7 +437,7 @@ bool ProcessGroupNCCL::WorkNCCL::timedOut() { } ProcessGroupNCCL::ProcessGroupNCCL( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b93bd0c2d70c..59f06fda1ec1 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -163,7 +163,7 @@ class ProcessGroupNCCL : public ProcessGroup { // Reference to the store so that we can write aborted communicators // to the store. - c10::intrusive_ptr store_; + std::shared_ptr store_; // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; @@ -393,7 +393,7 @@ class ProcessGroupNCCL : public ProcessGroup { // communicator. These NCCL communicators are cached and reused if possible. // ProcessGroupNCCL( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, Options options = Options()); @@ -402,7 +402,7 @@ class ProcessGroupNCCL : public ProcessGroup { // If you have existing code that uses the `groupName`, you can replace // it by specifying a `c10d::PrefixStore(groupName, store)` for store. C10_DEPRECATED ProcessGroupNCCL( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, const std::string& groupName, @@ -594,7 +594,7 @@ class ProcessGroupNCCL : public ProcessGroup { static const int64_t kWorkCleanupThreadSleepMillis; // The store is used to broadcast the NCCL unique ID of rank 0. - c10::intrusive_ptr store_; + std::shared_ptr store_; // The number of NCCL communicators that have been created during // the lifetime of this process group. This sequence number is diff --git a/torch/lib/c10d/Store.hpp b/torch/lib/c10d/Store.hpp index f97e80013cdb..e42bbf300e0b 100644 --- a/torch/lib/c10d/Store.hpp +++ b/torch/lib/c10d/Store.hpp @@ -6,11 +6,9 @@ #include #include -#include - namespace c10d { -class Store : public torch::CustomClassHolder { +class Store { public: static constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(300); diff --git a/torch/lib/c10d/frontend.hpp b/torch/lib/c10d/frontend.hpp index 3449ee30b5ef..69705427b53c 100644 --- a/torch/lib/c10d/frontend.hpp +++ b/torch/lib/c10d/frontend.hpp @@ -35,7 +35,7 @@ class DistributedC10d { const std::chrono::milliseconds& timeout, int64_t world_size, int64_t rank, - c10::intrusive_ptr store, + std::shared_ptr store, const std::string& group_name); void destroyProcessGroup(std::shared_ptr group); @@ -202,7 +202,7 @@ class DistributedC10d { // need to use ProcessGroup or ProcesGroup* as key. std::unordered_map< std::shared_ptr, - std::pair>> + std::pair>> pg_map_; // Note, this is different mapping relationship than original Python diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index ce75c78adce7..cc8da6326091 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -41,7 +41,7 @@ std::string tmppath() { void testGetSet(std::string path, std::string prefix = "") { // Basic Set/Get on File Store { - auto fileStore = c10::make_intrusive(path, 2); + auto fileStore = std::make_shared(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -53,7 +53,7 @@ void testGetSet(std::string path, std::string prefix = "") { // Perform get on new instance { - auto fileStore = c10::make_intrusive(path, 2); + auto fileStore = std::make_shared(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::check(store, "key0", "value0"); } @@ -69,8 +69,7 @@ void stressTestStore(std::string path, std::string prefix = "") { for (auto i = 0; i < numThreads; i++) { threads.push_back(std::thread([&] { - auto fileStore = - c10::make_intrusive(path, numThreads + 1); + auto fileStore = std::make_shared(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); @@ -88,7 +87,7 @@ void stressTestStore(std::string path, std::string prefix = "") { // Check that the counter has the expected value { - auto fileStore = c10::make_intrusive(path, numThreads + 1); + auto fileStore = std::make_shared(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(store, "counter", expected); diff --git a/torch/lib/c10d/test/HashStoreTest.cpp b/torch/lib/c10d/test/HashStoreTest.cpp index 24b7fc76a417..a16f83231a58 100644 --- a/torch/lib/c10d/test/HashStoreTest.cpp +++ b/torch/lib/c10d/test/HashStoreTest.cpp @@ -11,7 +11,7 @@ void testGetSet(std::string prefix = "") { // Basic set/get { - auto hashStore = c10::make_intrusive(); + auto hashStore = std::make_shared(); c10d::PrefixStore store(prefix, hashStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -32,7 +32,7 @@ void testGetSet(std::string prefix = "") { // get() waits up to timeout_. { - auto hashStore = c10::make_intrusive(); + auto hashStore = std::make_shared(); c10d::PrefixStore store(prefix, hashStore); std::thread th([&]() { c10d::test::set(store, "key0", "value0"); }); c10d::test::check(store, "key0", "value0"); @@ -47,7 +47,7 @@ void stressTestStore(std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; - auto hashStore = c10::make_intrusive(); + auto hashStore = std::make_shared(); c10d::PrefixStore store(prefix, hashStore); for (auto i = 0; i < numThreads; i++) { diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 091ea9b2ad07..1363a842eab3 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -45,7 +45,7 @@ class AsyncTest { } void start(int rank, int size) { - auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + auto store = std::make_shared<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index 469cf32a8442..de993a1110b4 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -45,7 +45,7 @@ class SignalTest { } c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { - auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + auto store = std::make_shared<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; // Set a timeout that is small enough to make this test run fast, but also @@ -101,7 +101,7 @@ c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { public: ProcessGroupGlooDelayed( - const c10::intrusive_ptr<::c10d::Store>& store, + const std::shared_ptr<::c10d::Store>& store, int rank, int size, Options options) @@ -151,7 +151,7 @@ class CollectiveTest { } void start(int rank, int size, bool delayed) { - auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + auto store = std::make_shared<::c10d::FileStore>(path_, size); // Set a timeout that is small enough to make this test run fast, but also // make sure that we don't get timeouts in the ProcessGroupGloo constructor. diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index e19981c523de..f1348922e126 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -37,7 +37,7 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { public: ProcessGroupNCCLSimulateErrors( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -106,7 +106,7 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { public: ProcessGroupNCCLTimedOutErrors( - const c10::intrusive_ptr& store, + const std::shared_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -153,7 +153,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { void SetUp() override { size_t numDevices = cudaNumDevices(); TemporaryFile file; - store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); + store_ = std::make_shared<::c10d::FileStore>(file.path, 1); at::cuda::OptionalCUDAGuard deviceGuard; tensors_.resize(numDevices); @@ -168,7 +168,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { } std::vector tensors_; - c10::intrusive_ptr<::c10d::FileStore> store_; + std::shared_ptr<::c10d::FileStore> store_; }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index fa5e988273fc..efa96312aba0 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -31,7 +31,7 @@ class NCCLTestBase { } void initialize(int rank, int size) { - auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + auto store = std::make_shared<::c10d::FileStore>(path_, size); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size)); diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 8073ec0345e0..0cfa72c7801a 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -16,7 +16,7 @@ void testHelper(const std::string& prefix = "") { const auto numThreads = 16; const auto numWorkers = numThreads + 1; - auto serverTCPStore = c10::make_intrusive( + auto serverTCPStore = std::make_shared( "127.0.0.1", 0, numWorkers, @@ -25,7 +25,7 @@ void testHelper(const std::string& prefix = "") { /* wait */ false); auto serverStore = - c10::make_intrusive(prefix, serverTCPStore); + std::make_unique(prefix, serverTCPStore); // server store auto serverThread = std::thread([&serverStore, &serverTCPStore] { // Wait for all workers to join. @@ -64,13 +64,13 @@ void testHelper(const std::string& prefix = "") { c10d::test::Semaphore sem1, sem2; // Each thread will have a client store to send/recv data - std::vector> clientTCPStores; - std::vector> clientStores; + std::vector> clientTCPStores; + std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - clientTCPStores.push_back(c10::make_intrusive( + clientTCPStores.push_back(std::make_unique( "127.0.0.1", serverTCPStore->getPort(), numWorkers, false)); - clientStores.push_back( - c10::make_intrusive(prefix, clientTCPStores[i])); + clientStores.push_back(std::unique_ptr( + new c10d::PrefixStore(prefix, clientTCPStores[i]))); } std::string expectedCounterRes = std::to_string(numThreads * numIterations + 1); From dac0192148a5336fa066cc39a23a8a0c2f236584 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 10:42:13 -0800 Subject: [PATCH 13/93] Revert D23632280: [c10d] switch ProcessGroup::Work to be managed by intrusive_ptr Test Plan: revert-hammer Differential Revision: D23632280 (https://github.com/pytorch/pytorch/commit/0650a6166ff42a65b431f527d0f9a76f5be44e37) Original commit changeset: 0a4642a8ffab fbshipit-source-id: 2aa8ddb874fab11f773f4c08d740afcd865482e9 --- test/cpp_extensions/cpp_c10d_extension.cpp | 32 +++--- test/cpp_extensions/cpp_c10d_extension.hpp | 26 ++--- torch/csrc/distributed/c10d/init.cpp | 5 +- .../distributed/rpc/process_group_agent.cpp | 2 +- .../distributed/rpc/process_group_agent.h | 4 +- torch/lib/c10d/ProcessGroup.cpp | 2 +- torch/lib/c10d/ProcessGroup.hpp | 35 +++--- torch/lib/c10d/ProcessGroupGloo.cpp | 101 +++++++++--------- torch/lib/c10d/ProcessGroupGloo.hpp | 38 +++---- torch/lib/c10d/ProcessGroupMPI.cpp | 42 ++++---- torch/lib/c10d/ProcessGroupMPI.hpp | 36 +++---- torch/lib/c10d/ProcessGroupNCCL.cpp | 52 ++++----- torch/lib/c10d/ProcessGroupNCCL.hpp | 46 ++++---- torch/lib/c10d/ProcessGroupRoundRobin.cpp | 30 +++--- torch/lib/c10d/ProcessGroupRoundRobin.hpp | 30 +++--- torch/lib/c10d/comm.cpp | 4 +- torch/lib/c10d/example/allreduce.cpp | 2 +- torch/lib/c10d/reducer.cpp | 2 +- torch/lib/c10d/reducer.hpp | 9 +- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 10 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 16 +-- torch/lib/c10d/test/ProcessGroupMPITest.cpp | 38 ++++--- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 12 +-- 24 files changed, 287 insertions(+), 295 deletions(-) diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index 50e5f5861caa..b4901cdbcf4d 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -23,85 +23,85 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size) ProcessGroupTest::~ProcessGroupTest() {} -c10::intrusive_ptr ProcessGroupTest::broadcast( +std::shared_ptr ProcessGroupTest::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - return c10::make_intrusive(); + return std::make_shared(); } -c10::intrusive_ptr ProcessGroupTest::allreduce( +std::shared_ptr ProcessGroupTest::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - return c10::make_intrusive(); + return std::make_shared(); } -c10::intrusive_ptr ProcessGroupTest::allreduce_coalesced( +std::shared_ptr ProcessGroupTest::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); } -c10::intrusive_ptr ProcessGroupTest::reduce( +std::shared_ptr ProcessGroupTest::reduce( std::vector& tensors, const ReduceOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce"); } -c10::intrusive_ptr ProcessGroupTest::allgather( +std::shared_ptr ProcessGroupTest::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather"); } -c10::intrusive_ptr ProcessGroupTest::allgather_base( +std::shared_ptr ProcessGroupTest::allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather_base"); } -c10::intrusive_ptr ProcessGroupTest::barrier( +std::shared_ptr ProcessGroupTest::barrier( const BarrierOptions& opts) { - return c10::make_intrusive(); + return std::make_shared(); } -c10::intrusive_ptr ProcessGroupTest::gather( +std::shared_ptr ProcessGroupTest::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support gather"); } -c10::intrusive_ptr ProcessGroupTest::scatter( +std::shared_ptr ProcessGroupTest::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support scatter"); } -c10::intrusive_ptr ProcessGroupTest::reduce_scatter( +std::shared_ptr ProcessGroupTest::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); } -c10::intrusive_ptr ProcessGroupTest::send( +std::shared_ptr ProcessGroupTest::send( std::vector& tensors, int dstRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support send"); } -c10::intrusive_ptr ProcessGroupTest::recv( +std::shared_ptr ProcessGroupTest::recv( std::vector& tensors, int srcRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support recv"); } -c10::intrusive_ptr ProcessGroupTest::recvAnysource( +std::shared_ptr ProcessGroupTest::recvAnysource( std::vector& tensor, int tag) { throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index 8aeec736d440..d8dffcd20327 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -41,61 +41,61 @@ class ProcessGroupTest : public ProcessGroup { explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); - c10::intrusive_ptr broadcast( + std::shared_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - c10::intrusive_ptr allreduce( + std::shared_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( + std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - c10::intrusive_ptr reduce( + std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - c10::intrusive_ptr allgather( + std::shared_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_base( + std::shared_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr barrier( + std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - c10::intrusive_ptr gather( + std::shared_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - c10::intrusive_ptr scatter( + std::shared_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - c10::intrusive_ptr reduce_scatter( + std::shared_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr send( + std::shared_ptr send( std::vector& tensors, int dstRank, int tag); - c10::intrusive_ptr recv( + std::shared_ptr recv( std::vector& tensors, int srcRank, int tag); - c10::intrusive_ptr recvAnysource( + std::shared_ptr recvAnysource( std::vector& tensor, int tag); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 136efd32fc87..d9ddf35ee1df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,6 +1,5 @@ #include -#include #include #ifndef _WIN32 #include @@ -60,8 +59,6 @@ constexpr auto kDeprecationWarning = "{} API is being deprecated, please ping " "https://github.com/pytorch/pytorch/issues/46291 " "if you see this warning"; -template -using intrusive_ptr_class_ = py::class_>; // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. @@ -1048,7 +1045,7 @@ that adds a prefix to each key inserted to the store. py::call_guard()); #endif - intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") + shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) .def( "is_success", diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 13e685b8fe74..2f29adc8f0c4 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -398,7 +398,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) { // ProcessGroup is not thread-safe when sending with the same tag, // hence the lock - std::vector> pendingSends; + std::vector> pendingSends; const auto dst = work.to_.id_; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 70fb1b40244d..1bc8db9ebf20 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // Lock and shared ptr to currently pending work, set in listenloop() and // interruptible in shutdown(). std::mutex recvWorkMutex_; - c10::intrusive_ptr recvWork_; + std::shared_ptr recvWork_; // Map of dst rank to current oustanding sends that we are waiting on. In the // case of a call to ::shutdown() while we are still waiting on these sends, // the pending sends contained in this map will be aborted, allowing the // waiting thread to be unblocked. std::unordered_map< worker_id_t, - std::set>> + std::set>> currentPendingSends_; // Lock to serialize access to the above map. std::mutex pendingSendMutex_; diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 1d0d451f21a9..3521ed42c840 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -164,7 +164,7 @@ ProcessGroup::~ProcessGroup() {} // This is introduced so that implementors of ProcessGroup would not need to // have this implmentation. -c10::intrusive_ptr ProcessGroup::allgather_coalesced( +std::shared_ptr ProcessGroup::allgather_coalesced( std::vector>& /* usused */, std::vector& /* usused */, const AllgatherOptions& /* usused */) { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 63996b516a06..5e90dccc25c0 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -70,11 +70,12 @@ bool isP2POp(OpType opType); // class ProcessGroup { public: + // Please do not use ProcessGroup::Work API, it is going away, to be // replaced by ivalue::Future. // Python binding for this class might change, please do not assume // this will be bound using pybind. - class Work : public torch::CustomClassHolder { + class Work { public: Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); @@ -170,25 +171,25 @@ class ProcessGroup { return size_; } - virtual c10::intrusive_ptr broadcast( + virtual std::shared_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) = 0; - virtual c10::intrusive_ptr allreduce( + virtual std::shared_ptr allreduce( std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; // This will be moved out of ProcessGroup, do not add dependencies on this // function. - virtual c10::intrusive_ptr allreduce_coalesced( + virtual std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; - virtual c10::intrusive_ptr reduce( + virtual std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) = 0; - virtual c10::intrusive_ptr allgather( + virtual std::shared_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -196,7 +197,7 @@ class ProcessGroup { // Gathers a single tensor inputBuffer into a single buffer outputBuffer that // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. - virtual c10::intrusive_ptr allgather_base( + virtual std::shared_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -205,27 +206,27 @@ class ProcessGroup { // * do not add dependencies on this function, // * do not implement it in your ProcessGroup, implement allgather_base // instead. - virtual c10::intrusive_ptr allgather_coalesced( + virtual std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()); - virtual c10::intrusive_ptr gather( + virtual std::shared_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) = 0; - virtual c10::intrusive_ptr scatter( + virtual std::shared_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) = 0; - virtual c10::intrusive_ptr reduce_scatter( + virtual std::shared_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; - virtual c10::intrusive_ptr alltoall_base( + virtual std::shared_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -234,28 +235,28 @@ class ProcessGroup { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual c10::intrusive_ptr alltoall( + virtual std::shared_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual c10::intrusive_ptr send( + virtual std::shared_ptr send( std::vector& tensors, int dstRank, int tag) = 0; - virtual c10::intrusive_ptr recv( + virtual std::shared_ptr recv( std::vector& tensors, int srcRank, int tag) = 0; - virtual c10::intrusive_ptr recvAnysource( + virtual std::shared_ptr recvAnysource( std::vector& tensors, int tag) = 0; - virtual c10::intrusive_ptr barrier( + virtual std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) = 0; protected: diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 90c9b695de28..cd3e83e6b714 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -38,7 +38,6 @@ #endif #include -#include #include #include #include @@ -654,11 +653,11 @@ void ProcessGroupGloo::runLoop(int workerIndex) { AsyncWork::execute(std::move(work)); lock.lock(); - workInProgress_[workerIndex].reset(); + workInProgress_[workerIndex] = nullptr; } } -void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { +void ProcessGroupGloo::enqueue(std::shared_ptr work) { std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -774,7 +773,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::broadcast( +std::shared_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -797,15 +796,15 @@ c10::intrusive_ptr ProcessGroupGloo::broadcast( invalidArgument(c10::str("unsupported device type ", device.type())); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { @@ -1301,7 +1300,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::allreduce( +std::shared_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1330,15 +1329,15 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce( "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1346,10 +1345,10 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce( #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1363,7 +1362,7 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce( return work; } -c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( +std::shared_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1406,12 +1405,12 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( invalidArgument("unsupported layout"); } - c10::intrusive_ptr work; + std::shared_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); @@ -1539,7 +1538,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::reduce( +std::shared_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1562,11 +1561,11 @@ c10::intrusive_ptr ProcessGroupGloo::reduce( invalidArgument(c10::str("unsupported device type ", device.type())); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.rootRank, @@ -1575,7 +1574,7 @@ c10::intrusive_ptr ProcessGroupGloo::reduce( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), inputs, opts.rootRank, @@ -1721,7 +1720,7 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. -c10::intrusive_ptr ProcessGroupGloo::allgather( +std::shared_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { @@ -1770,15 +1769,15 @@ c10::intrusive_ptr ProcessGroupGloo::allgather( invalidArgument(c10::str("unsupported device type ", device.type())); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, tag); #endif } else { @@ -1853,7 +1852,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( +std::shared_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { @@ -1903,13 +1902,13 @@ c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( auto tag = nextTag(); auto context = getContext(tag); - auto work = c10::make_intrusive( + auto work = std::make_shared( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } -c10::intrusive_ptr ProcessGroupGloo::allgather_base( +std::shared_ptr ProcessGroupGloo::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { @@ -2058,7 +2057,7 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::gather( +std::shared_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { @@ -2104,15 +2103,15 @@ c10::intrusive_ptr ProcessGroupGloo::gather( invalidArgument(c10::str("unsupported device type ", device.type())); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2246,7 +2245,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::scatter( +std::shared_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { @@ -2291,15 +2290,15 @@ c10::intrusive_ptr ProcessGroupGloo::scatter( invalidArgument(c10::str("unsupported device type ", device.type())); } - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2309,7 +2308,7 @@ c10::intrusive_ptr ProcessGroupGloo::scatter( return work; } -c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( +std::shared_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { @@ -2444,7 +2443,7 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::alltoall_base( +std::shared_ptr ProcessGroupGloo::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2461,12 +2460,12 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( assertDense(invalidArgument, {inputTensor}); const auto& device = outputTensor.device(); - c10::intrusive_ptr work; + std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputTensor, inputTensor, @@ -2475,7 +2474,7 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = c10::make_intrusive( + work = std::make_shared( std::move(context), outputTensor, inputTensor, @@ -2511,7 +2510,7 @@ uint32_t checkTag(int32_t tag) { return (uint32_t)tag; } -c10::intrusive_ptr ProcessGroupGloo::send( +std::shared_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { @@ -2527,10 +2526,10 @@ c10::intrusive_ptr ProcessGroupGloo::send( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the send. - return c10::make_intrusive(tensor, std::move(buf)); + return std::make_shared(tensor, std::move(buf)); } -c10::intrusive_ptr ProcessGroupGloo::recv( +std::shared_ptr ProcessGroupGloo::recv( std::vector& tensors, int srcRank, int tag) { @@ -2546,10 +2545,10 @@ c10::intrusive_ptr ProcessGroupGloo::recv( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return c10::make_intrusive(tensor, std::move(buf)); + return std::make_shared(tensor, std::move(buf)); } -c10::intrusive_ptr ProcessGroupGloo::recvAnysource( +std::shared_ptr ProcessGroupGloo::recvAnysource( std::vector& tensors, int tag) { auto& tensor = checkSingleTensor(tensors); @@ -2574,7 +2573,7 @@ c10::intrusive_ptr ProcessGroupGloo::recvAnysource( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return c10::make_intrusive(tensor, std::move(buf)); + return std::make_shared(tensor, std::move(buf)); } namespace { @@ -2583,13 +2582,13 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( const std::shared_ptr& context, - std::vector> priorWork, + std::vector> priorWork, uint32_t tag) : ProcessGroupGloo::AsyncWork("gloo:barrier"), context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; - std::vector> priorWork; + std::vector> priorWork; const uint32_t tag; void run() override { @@ -2609,9 +2608,9 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } // namespace -c10::intrusive_ptr ProcessGroupGloo::barrier( +std::shared_ptr ProcessGroupGloo::barrier( const BarrierOptions& opts) { - std::vector> priorWork; + std::vector> priorWork; // Snapshot all in progress and pending work as weak_ptr. // When executing a barrier, we need to ensure that all prior work @@ -2625,7 +2624,7 @@ c10::intrusive_ptr ProcessGroupGloo::barrier( auto tag = nextTag(); auto context = getContext(tag); - auto work = c10::make_intrusive( + auto work = std::make_shared( std::move(context), std::move(priorWork), tag); enqueue(work); return work; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 74fd0f6e5165..31664ad0b6cf 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -70,7 +70,7 @@ class ProcessGroupGloo : public ProcessGroup { public: AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} - static void execute(c10::intrusive_ptr work) { + static void execute(std::shared_ptr work) { std::exception_ptr eptr; try { work->run(); @@ -159,75 +159,75 @@ class ProcessGroupGloo : public ProcessGroup { virtual ~ProcessGroupGloo(); - c10::intrusive_ptr broadcast( + std::shared_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - c10::intrusive_ptr allreduce( + std::shared_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( + std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - c10::intrusive_ptr reduce( + std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - c10::intrusive_ptr allgather( + std::shared_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_base( + std::shared_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_coalesced( + std::shared_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr gather( + std::shared_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - c10::intrusive_ptr scatter( + std::shared_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - c10::intrusive_ptr reduce_scatter( + std::shared_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr alltoall_base( + std::shared_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr send( + std::shared_ptr send( std::vector& tensors, int dstRank, int tag) override; - c10::intrusive_ptr recv( + std::shared_ptr recv( std::vector& tensors, int srcRank, int tag) override; - c10::intrusive_ptr recvAnysource( + std::shared_ptr recvAnysource( std::vector& tensors, int tag) override; - c10::intrusive_ptr barrier( + std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; protected: @@ -258,7 +258,7 @@ class ProcessGroupGloo : public ProcessGroup { void runLoop(int workerIndex); // Queue work to run on worker thread. - void enqueue(c10::intrusive_ptr work); + void enqueue(std::shared_ptr work); // Keep both a queue of pending work, and a vector with in progress work. // Both of these can only be mutated when holding the queue lock. @@ -266,8 +266,8 @@ class ProcessGroupGloo : public ProcessGroup { // to all in progress and pending work when executing a barrier. // When executing a barrier, we need to ensure that all prior work // has completed before completing itself. - std::deque> workQueue_; - std::vector> workInProgress_; + std::deque> workQueue_; + std::vector> workInProgress_; std::mutex workMutex_; std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index 5f9d0be41b8f..d3e79a1dd424 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -308,9 +308,9 @@ void ProcessGroupMPI::runLoop() { } } -c10::intrusive_ptr ProcessGroupMPI::enqueue( +std::shared_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry) { - auto work = c10::make_intrusive(); + auto work = std::make_shared(); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -318,7 +318,7 @@ c10::intrusive_ptr ProcessGroupMPI::enqueue( return work; } -c10::intrusive_ptr ProcessGroupMPI::broadcast( +std::shared_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { checkSingleTensor(tensors); @@ -339,7 +339,7 @@ c10::intrusive_ptr ProcessGroupMPI::broadcast( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::allreduce( +std::shared_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { checkSingleTensor(tensors); @@ -362,14 +362,14 @@ c10::intrusive_ptr ProcessGroupMPI::allreduce( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( +std::shared_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with MPI"); } -c10::intrusive_ptr ProcessGroupMPI::reduce( +std::shared_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { checkSingleTensor(tensors); @@ -397,7 +397,7 @@ c10::intrusive_ptr ProcessGroupMPI::reduce( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::allgather( +std::shared_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -441,7 +441,7 @@ c10::intrusive_ptr ProcessGroupMPI::allgather( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( +std::shared_ptr ProcessGroupMPI::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -449,7 +449,7 @@ c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( "ProcessGroupMPI does not support allgather_coalesced"); } -c10::intrusive_ptr ProcessGroupMPI::gather( +std::shared_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { @@ -516,7 +516,7 @@ c10::intrusive_ptr ProcessGroupMPI::gather( } } -c10::intrusive_ptr ProcessGroupMPI::scatter( +std::shared_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { @@ -582,14 +582,14 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( } } -c10::intrusive_ptr ProcessGroupMPI::reduce_scatter( +std::shared_ptr ProcessGroupMPI::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); } -c10::intrusive_ptr ProcessGroupMPI::alltoall_base( +std::shared_ptr ProcessGroupMPI::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -665,7 +665,7 @@ c10::intrusive_ptr ProcessGroupMPI::alltoall_base( return enqueue(std::move(entry)); } } -c10::intrusive_ptr ProcessGroupMPI::alltoall( +std::shared_ptr ProcessGroupMPI::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts) { @@ -722,7 +722,7 @@ c10::intrusive_ptr ProcessGroupMPI::alltoall( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::send( +std::shared_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { @@ -744,10 +744,10 @@ c10::intrusive_ptr ProcessGroupMPI::send( &request)); } - return c10::make_intrusive(tensor, request); + return std::make_shared(tensor, request); } -c10::intrusive_ptr ProcessGroupMPI::recv( +std::shared_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { @@ -769,10 +769,10 @@ c10::intrusive_ptr ProcessGroupMPI::recv( &request)); } - return c10::make_intrusive(tensor, request); + return std::make_shared(tensor, request); } -c10::intrusive_ptr ProcessGroupMPI::recvAnysource( +std::shared_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { checkSingleTensor(tensors); @@ -793,10 +793,10 @@ c10::intrusive_ptr ProcessGroupMPI::recvAnysource( &request)); } - return c10::make_intrusive(tensor, request); + return std::make_shared(tensor, request); } -c10::intrusive_ptr ProcessGroupMPI::barrier( +std::shared_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { std::function&)> runFunc = [this](std::unique_ptr& entry) { @@ -808,7 +808,7 @@ c10::intrusive_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } -c10::intrusive_ptr ProcessGroupMPI::allgather_base( +std::shared_ptr ProcessGroupMPI::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 48d95eada887..342fe87001a0 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -108,80 +108,80 @@ class ProcessGroupMPI : public ProcessGroup { // Abort the MPI program, needs to be called when exception is detected void abort(); - c10::intrusive_ptr broadcast( + std::shared_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - c10::intrusive_ptr allreduce( + std::shared_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( + std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - c10::intrusive_ptr reduce( + std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - c10::intrusive_ptr allgather( + std::shared_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_base( + std::shared_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_coalesced( + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr gather( + std::shared_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - c10::intrusive_ptr scatter( + std::shared_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - c10::intrusive_ptr reduce_scatter( + std::shared_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr alltoall_base( + std::shared_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr alltoall( + std::shared_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr send( + std::shared_ptr send( std::vector& tensors, int dstRank, int tag); - c10::intrusive_ptr recv( + std::shared_ptr recv( std::vector& tensors, int srcRank, int tag); - c10::intrusive_ptr recvAnysource( + std::shared_ptr recvAnysource( std::vector& tensor, int tag); - c10::intrusive_ptr barrier( + std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized @@ -190,13 +190,13 @@ class ProcessGroupMPI : public ProcessGroup { protected: using WorkType = - std::tuple, c10::intrusive_ptr>; + std::tuple, std::shared_ptr>; // Worker thread loop void runLoop(); // Helper function that is called by the destructor void destroy(); - c10::intrusive_ptr enqueue(std::unique_ptr entry); + std::shared_ptr enqueue(std::unique_ptr entry); bool stop_; diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index bd1563226343..ba0b4b36c77d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -984,12 +984,12 @@ std::vector flatten_for_scatter_gather( } // namespace -c10::intrusive_ptr ProcessGroupNCCL::initWork( +std::shared_ptr ProcessGroupNCCL::initWork( std::vector devices, int rank, OpType opType, const char* profilingTitle) { - return c10::make_intrusive(devices, rank, opType); + return std::make_shared(devices, rank, opType, profilingTitle); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1012,7 +1012,7 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: } void ProcessGroupNCCL::workEnqueue( - c10::intrusive_ptr work) { + std::shared_ptr work) { if (!terminateProcessGroup_.load()) { std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. @@ -1027,7 +1027,7 @@ ProcessGroupNCCL::Options::Options() isHighPriorityStream(false) {} template -c10::intrusive_ptr ProcessGroupNCCL::collective( +std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1114,7 +1114,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } template -c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( +std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, @@ -1186,7 +1186,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } template -c10::intrusive_ptr ProcessGroupNCCL::collective( +std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1203,7 +1203,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } template -c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( +std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -1217,7 +1217,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( [](std::vector&) {}); } -c10::intrusive_ptr ProcessGroupNCCL::allreduce( +std::shared_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_gpu_tensors(tensors); @@ -1242,14 +1242,14 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( "nccl:all_reduce"); } -c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( +std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with NCCL"); } -c10::intrusive_ptr ProcessGroupNCCL::broadcast( +std::shared_ptr ProcessGroupNCCL::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_gpu_tensors(tensors); @@ -1274,7 +1274,7 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( "nccl:broadcast"); } -c10::intrusive_ptr ProcessGroupNCCL::reduce( +std::shared_ptr ProcessGroupNCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { check_gpu_tensors(tensors); @@ -1301,7 +1301,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( "nccl:reduce"); } -c10::intrusive_ptr ProcessGroupNCCL::allgather( +std::shared_ptr ProcessGroupNCCL::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -1346,7 +1346,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( "nccl:all_gather"); } -c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( +std::shared_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -1354,7 +1354,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( "ProcessGroupNCCL does not support allgather_coalesced"); } -c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( +std::shared_ptr ProcessGroupNCCL::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { @@ -1400,7 +1400,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( "nccl:reduce_scatter"); } -c10::intrusive_ptr ProcessGroupNCCL::barrier( +std::shared_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { std::vector devices; if (usedDeviceIdxs_.empty()) { @@ -1441,7 +1441,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( } #ifdef ENABLE_NCCL_P2P_SUPPORT -c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( +std::shared_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -1512,7 +1512,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( } } -c10::intrusive_ptr ProcessGroupNCCL::send( +std::shared_ptr ProcessGroupNCCL::send( std::vector& tensors, int dstRank, int /* unused */) { @@ -1531,7 +1531,7 @@ c10::intrusive_ptr ProcessGroupNCCL::send( return ret; } -c10::intrusive_ptr ProcessGroupNCCL::recv( +std::shared_ptr ProcessGroupNCCL::recv( std::vector& tensors, int srcRank, int /* unused */) { @@ -1550,7 +1550,7 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( return ret; } #else -c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( +std::shared_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, at::Tensor& /* unused */, std::vector& /* unused */, @@ -1560,7 +1560,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } -c10::intrusive_ptr ProcessGroupNCCL::send( +std::shared_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1568,7 +1568,7 @@ c10::intrusive_ptr ProcessGroupNCCL::send( "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } -c10::intrusive_ptr ProcessGroupNCCL::recv( +std::shared_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1591,34 +1591,34 @@ void ProcessGroupNCCL::groupEnd() { --ncclActiveGroupCounter_; } -c10::intrusive_ptr ProcessGroupNCCL::alltoall( +std::shared_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); } -c10::intrusive_ptr ProcessGroupNCCL::gather( +std::shared_ptr ProcessGroupNCCL::gather( std::vector>& /* unused */, std::vector& /* unused */, const GatherOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support gather"); } -c10::intrusive_ptr ProcessGroupNCCL::scatter( +std::shared_ptr ProcessGroupNCCL::scatter( std::vector& /* unused */, std::vector>& /* unused */, const ScatterOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( +std::shared_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } -c10::intrusive_ptr ProcessGroupNCCL::allgather_base( +std::shared_ptr ProcessGroupNCCL::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 59f06fda1ec1..1520604629f2 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -65,7 +65,7 @@ constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; class ProcessGroupNCCL : public ProcessGroup { public: class WorkNCCL : public ProcessGroup::Work, - public std::enable_shared_from_this { + public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); @@ -411,64 +411,64 @@ class ProcessGroupNCCL : public ProcessGroup { virtual ~ProcessGroupNCCL(); - c10::intrusive_ptr broadcast( + std::shared_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - c10::intrusive_ptr allreduce( + std::shared_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( + std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - c10::intrusive_ptr reduce( + std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - c10::intrusive_ptr allgather( + std::shared_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_base( + std::shared_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_coalesced( + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr reduce_scatter( + std::shared_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr barrier( + std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - c10::intrusive_ptr alltoall_base( + std::shared_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr alltoall( + std::shared_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr send( + std::shared_ptr send( std::vector& tensors, int dstRank, int tag) override; - c10::intrusive_ptr recv( + std::shared_ptr recv( std::vector& tensors, int srcRank, int tag) override; @@ -478,17 +478,17 @@ class ProcessGroupNCCL : public ProcessGroup { static void groupEnd(); // Unsupported Ops - c10::intrusive_ptr gather( + std::shared_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - c10::intrusive_ptr scatter( + std::shared_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - c10::intrusive_ptr recvAnysource( + std::shared_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -515,7 +515,7 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); - virtual c10::intrusive_ptr initWork( + virtual std::shared_ptr initWork( std::vector devices, int rank, OpType opType, @@ -529,14 +529,14 @@ class ProcessGroupNCCL : public ProcessGroup { // ncclComm_t, at::cuda::CUDAStream&); // void {pre,post}(std::vector); template - c10::intrusive_ptr collective( + std::shared_ptr collective( std::vector& input, std::vector& output, Fn fn, OpType opType, const char* profilingTitle = nullptr); template - c10::intrusive_ptr collective( + std::shared_ptr collective( std::vector& input, std::vector& output, Fn fn, @@ -549,13 +549,13 @@ class ProcessGroupNCCL : public ProcessGroup { // primitives. It is the same structure as the helper used for collective // communicaiton primitives. template - c10::intrusive_ptr pointToPoint( + std::shared_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, OpType opType); template - c10::intrusive_ptr pointToPoint( + std::shared_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -664,7 +664,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::list workMetaList_; // Add Work Pointer to workVector - void workEnqueue(c10::intrusive_ptr); + void workEnqueue(std::shared_ptr); // The CUDA steams used by NCCL kernels std::unordered_map> diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index c77188577a62..032f63c320f5 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -17,66 +17,66 @@ ProcessGroupRoundRobin::ProcessGroupRoundRobin( ProcessGroupRoundRobin::~ProcessGroupRoundRobin() {} -c10::intrusive_ptr ProcessGroupRoundRobin::broadcast( +std::shared_ptr ProcessGroupRoundRobin::broadcast( std::vector& tensors, const BroadcastOptions& opts) { return next()->broadcast(tensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( +std::shared_ptr ProcessGroupRoundRobin::allreduce( std::vector& tensors, const AllreduceOptions& opts) { return next()->allreduce(tensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( +std::shared_ptr ProcessGroupRoundRobin::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::reduce( +std::shared_ptr ProcessGroupRoundRobin::reduce( std::vector& tensors, const ReduceOptions& opts) { return next()->reduce(tensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::allgather( +std::shared_ptr ProcessGroupRoundRobin::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { return next()->allgather(outputs, inputs, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( +std::shared_ptr ProcessGroupRoundRobin::allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::gather( +std::shared_ptr ProcessGroupRoundRobin::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { return next()->gather(outputs, inputs, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::scatter( +std::shared_ptr ProcessGroupRoundRobin::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { return next()->scatter(outputs, inputs, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::reduce_scatter( +std::shared_ptr ProcessGroupRoundRobin::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { return next()->reduce_scatter(outputs, inputs, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( +std::shared_ptr ProcessGroupRoundRobin::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -86,27 +86,27 @@ c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::send( +std::shared_ptr ProcessGroupRoundRobin::send( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support send"); }; -c10::intrusive_ptr ProcessGroupRoundRobin::recv( +std::shared_ptr ProcessGroupRoundRobin::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -c10::intrusive_ptr ProcessGroupRoundRobin::recvAnysource( +std::shared_ptr ProcessGroupRoundRobin::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -c10::intrusive_ptr ProcessGroupRoundRobin::barrier( +std::shared_ptr ProcessGroupRoundRobin::barrier( const BarrierOptions& /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; @@ -120,7 +120,7 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } -c10::intrusive_ptr ProcessGroupRoundRobin::allgather_base( +std::shared_ptr ProcessGroupRoundRobin::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index 62d59ef18ce5..bbbd0a1c756b 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -25,75 +25,75 @@ class ProcessGroupRoundRobin final : public ProcessGroup { ~ProcessGroupRoundRobin() override; - c10::intrusive_ptr broadcast( + std::shared_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - c10::intrusive_ptr allreduce( + std::shared_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( + std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - c10::intrusive_ptr reduce( + std::shared_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - c10::intrusive_ptr allgather( + std::shared_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_base( + std::shared_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr allgather_coalesced( + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - c10::intrusive_ptr gather( + std::shared_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - c10::intrusive_ptr scatter( + std::shared_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - c10::intrusive_ptr reduce_scatter( + std::shared_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - c10::intrusive_ptr alltoall_base( + std::shared_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr send( + std::shared_ptr send( std::vector& tensors, int dstRank, int tag) override; - c10::intrusive_ptr recv( + std::shared_ptr recv( std::vector& tensors, int srcRank, int tag) override; - c10::intrusive_ptr recvAnysource( + std::shared_ptr recvAnysource( std::vector& tensors, int tag) override; - c10::intrusive_ptr barrier( + std::shared_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: diff --git a/torch/lib/c10d/comm.cpp b/torch/lib/c10d/comm.cpp index 5ef88f058aca..a8628e0c942e 100644 --- a/torch/lib/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -45,10 +45,8 @@ class BroadcastWork { // because c10d::ProcessGroup::broadcast takes a vector argument. std::vector flat_tensor_; - private: - // The broadcast work that is kicked off upon construction. - c10::intrusive_ptr work_; + std::shared_ptr work_; }; } // namespace diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 3de7447d092a..76d6a5588f7e 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -19,7 +19,7 @@ int main(int argc, char** argv) { } // Kick off work - std::vector> pending; + std::vector> pending; for (auto i = 0; i < ntensors; i++) { std::vector tmp = {tensors[i]}; pending.push_back(pg.allreduce(tmp)); diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index c5ee54a9ee8e..c05ce685bb7d 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -472,7 +472,7 @@ std::vector> Reducer::get_bucket_tensors() const { } void Reducer::set_forward_pass_work_handle( - c10::intrusive_ptr forwardPassWorkHandle, + std::shared_ptr forwardPassWorkHandle, bool useStaticWorldSize) { std::lock_guard lock(mutex_); forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index e0fe0004f88e..4874f0dd8703 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -8,7 +8,6 @@ #include #include -#include #include #include #include @@ -97,7 +96,7 @@ class Reducer { // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the // corresponding tensor being reduced. void set_forward_pass_work_handle( - c10::intrusive_ptr forwardPassWorkHandle, + std::shared_ptr forwardPassWorkHandle, bool useStaticWorldSize); // Retrieve on-device tensors used to track locally unused parameters. For @@ -159,7 +158,7 @@ class Reducer { bool local_used_maps_reduced_; // Work handle for allreduce on local_used_maps_ - c10::intrusive_ptr local_used_work_; + std::shared_ptr local_used_work_; void verify_replicas_within_process(); @@ -283,7 +282,7 @@ class Reducer { size_t pending; // Keep work handle around when this set of buckets is being reduced. - c10::intrusive_ptr work; + std::shared_ptr work; // Keep future work handle around if DDP comm hook is registered. c10::intrusive_ptr future_work; @@ -341,7 +340,7 @@ class Reducer { // A struct containing work handle and tensor for allreduce scheduled in // forward pass, if applicable. struct ForwardPassAllreduceWork { - c10::intrusive_ptr workHandle; + std::shared_ptr workHandle; at::Tensor resultTensor; // whether we should divide by the initial world_size or the no. of // remaining DDP ranks. diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 1363a842eab3..92dede9a573e 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -93,7 +93,7 @@ class AsyncInputIsOutputTest : public AsyncTest { } } - void wait(c10::intrusive_ptr& work) { + void wait(std::shared_ptr& work) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(); } @@ -130,7 +130,7 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest { AsyncAllreduceTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - c10::intrusive_ptr run() { + std::shared_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -156,7 +156,7 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + std::shared_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -185,7 +185,7 @@ void runAsyncAllreduceTest( size_t numProcesses = 4, size_t numTensors = 2) { auto tests = initialize(path, numProcesses, numTensors); - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(); } @@ -229,7 +229,7 @@ void runAsyncBroadcastTest( // Try every permutation of root rank and root tensor for (size_t rootRank = 0; rootRank < numProcesses; rootRank++) { for (size_t rootTensor = 0; rootTensor < numTensors; rootTensor++) { - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(rootRank, rootTensor); } diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index de993a1110b4..da4f9b5fc106 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -44,7 +44,7 @@ class SignalTest { }); } - c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { + std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { auto store = std::make_shared<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; @@ -62,7 +62,7 @@ class SignalTest { }; // Loop until an exception happens - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work; + std::shared_ptr<::c10d::ProcessGroup::Work> work; while (true) { work = pg.allreduce(tensors); try { @@ -82,7 +82,7 @@ class SignalTest { Semaphore sem_; }; -c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( +std::shared_ptr<::c10d::ProcessGroup::Work> testSignal( const std::string& path, int signal) { Fork fork; @@ -107,7 +107,7 @@ class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { Options options) : ProcessGroupGloo(store, rank, size, options) {} - c10::intrusive_ptr<::c10d::ProcessGroup::Work> send( + std::shared_ptr<::c10d::ProcessGroup::Work> send( std::vector& tensors, int dstRank, int tag) override { @@ -200,7 +200,7 @@ void testAllreduce(const std::string& path, const at::DeviceType b) { } // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().allreduce(inputs[i]); } @@ -250,7 +250,7 @@ void testBroadcast(const std::string& path, const at::DeviceType b) { options.rootTensor = j; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().broadcast(inputs[i], options); } @@ -316,7 +316,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { }; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto rank = 0; rank < size; rank++) { work[rank] = tests[rank].getProcessGroup().alltoall_base( outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]); @@ -349,7 +349,7 @@ void testBarrier(const std::string& path) { auto tests = CollectiveTest::initialize(path, size); // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().barrier(); } diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp index 6c60b3d6742d..3f5a9e4cf331 100644 --- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp +++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp @@ -14,7 +14,7 @@ // Wait for work to complete void waitWork( std::shared_ptr pg, - std::vector> works) { + std::vector> works) { for (auto& work : works) { try { work->wait(); @@ -34,11 +34,10 @@ void testAllreduce(int iter = 1000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = - pg->allreduce(tensors); + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allreduce(tensors); works.push_back(std::move(work)); } @@ -74,11 +73,10 @@ void testBroadcast(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = - pg->broadcast(tensors); + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->broadcast(tensors); works.push_back(std::move(work)); } @@ -106,10 +104,10 @@ void testReduce(int iter = 10000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); works.push_back(std::move(work)); } @@ -152,10 +150,10 @@ void testAllgather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allgather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -200,10 +198,10 @@ void testGather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->gather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -251,10 +249,10 @@ void testScatter(int iter = 1) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->scatter(allTensors[i], allInputTensors[i]); works.push_back(std::move(work)); } @@ -291,27 +289,27 @@ void testSendRecv(bool recvAnysource, int iter = 10000) { } if (rank == 0) { - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->send(tensors, 1, 0); works.push_back(std::move(work)); } waitWork(pg, works); } if (rank == 1) { - std::vector> works; + std::vector> works; std::vector srcRanks(allTensors.size(), -1); size_t i = 0; for (auto& tensors : allTensors) { // Kick off work if (!recvAnysource) { - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->recv(tensors, 0, 0); works.push_back(std::move(work)); } else { - c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->recvAnysource(tensors, 0); works.push_back(std::move(work)); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index f1348922e126..e906702a889d 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -56,12 +56,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); } - c10::intrusive_ptr initWork( + std::shared_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return c10::make_intrusive( + return std::make_shared( devices, simulate_error_, rank, opType); } @@ -113,12 +113,12 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} - c10::intrusive_ptr initWork( + std::shared_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return c10::make_intrusive( + return std::make_shared( devices, set_timedout_error_, rank, opType); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index efa96312aba0..92b477fae7de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -80,7 +80,7 @@ class NCCLTest : public NCCLTestBase { } void wait( - c10::intrusive_ptr& work, + std::shared_ptr& work, std::chrono::milliseconds timeout = kNoTimeout) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(timeout); @@ -166,7 +166,7 @@ class AllreduceNCCLTest : public NCCLTest { AllreduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run() { + std::shared_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -189,7 +189,7 @@ class BroadcastNCCLTest : public NCCLTest { BroadcastNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + std::shared_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -208,7 +208,7 @@ class ReduceNCCLTest : public NCCLTest { ReduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + std::shared_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -227,7 +227,7 @@ class AllgatherNCCLTest : public NCCLTest { AllgatherNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run() { + std::shared_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -242,7 +242,7 @@ struct ReduceScatterNCCLTest : NCCLTest { ReduceScatterNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run() { + std::shared_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); From d1351c66a8da35079dea97131946b0424059467c Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 11 Nov 2020 10:54:01 -0800 Subject: [PATCH 14/93] [FX] Add a bunch of docstrings (#47719) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47719 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D24875400 Pulled By: jamesr66a fbshipit-source-id: a1dd43d2eee914a441eff43c4f2efe61a399e8a5 --- torch/fx/graph.py | 167 +++++++++++++++++++++++++++++++++++-- torch/fx/graph_module.py | 19 ++++- torch/fx/node.py | 28 ++++++- torch/fx/symbolic_trace.py | 67 +++++++++++++-- 4 files changed, 264 insertions(+), 17 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index dd07ff7a508e..45e518410145 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -13,7 +13,7 @@ def _shadows_builtin_name(name: str) -> bool: def _is_magic(x: str) -> bool: return x.startswith('__') and x.endswith('__') -def snake_case(s: str) -> str: +def _snake_case(s: str) -> str: return ''.join(['_' + i.lower() if i.isupper() else i for i in s]).lstrip('_') def get_qualified_name(func: Callable[..., Any]) -> str: @@ -108,6 +108,68 @@ def __reversed__(self): return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') class Graph: + """ + `Graph` is the main data structure used in the FX Intermediate Representation. + It consists of a series of `Node`s, each representing callsites (or other + syntactic constructs). The list of `Node`s, taken together, constitute a + valid Python function. + + For example, the following code + + ``` + import torch + from torch.fx import symbolic_trace + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = symbolic_trace(m) + ``` + + Will produce the following Graph: + + ``` + print(gm.graph) + ``` + + ``` + graph(x): + %linear_weight : [uses=1] = self.linear.weight + %add_1 : [uses=1] = call_function[target=](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [uses=1] = call_function[target=](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950 + %topk_1 : [uses=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 + return topk_1 + ``` + + The Node semantics are as follows: + + - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. + `target` is similarly the name of the argument. `args` and `kwargs` are don't-care. Placeholders correspond to + the function parameters (e.g. `x`) in the graph printout. + - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the + fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. + `args` and `kwargs` are don't-care + - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign + to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, + following the Python calling convention + - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is + as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. + `args` and `kwargs` represent the arguments to invoke the module on, _including the self argument_. + - `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method + to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, + _including the self argument_. + - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement + in the Graph printout. + """ def __init__(self): """ Construct an empty Graph. @@ -118,7 +180,13 @@ def __init__(self): self._len = 0 @property - def nodes(self): + def nodes(self) -> _node_list: + """ + Get the list of `Node`s that constitute this Graph. + + Note that this `Node` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + """ return _node_list(self) def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]: @@ -156,6 +224,21 @@ def create_node(self, op: str, target: Target, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> Node: + """ + Create a `Node` and add it to the `Graph` at the current insert-point. + Note that the current insert-point can be set via `Graph.inserting_before` + and `Graph.inserting_after`. + + - op is the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the `Graph` docstring. + - args is a tuple of arguments to this node. + - kwargs is a dict from string to argument, representing the kwargs of this Node + - name is an optional string name for the `Node`. This will influence the name + of the value assigned to in the Python generated code. + - type_expr is an optional type annotation representing the Python type + the output of this node will have. + """ assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') args = () if args is None else args kwargs = {} if kwargs is None else kwargs @@ -224,16 +307,49 @@ def inserting_after(self, n: Optional[Node] = None): # sugar for create_node when you know the op def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a `placeholder` node into the Graph. A `placeholder` represents + a function input. This function takes a string `name` for the input + value as well as an optional `type_expr`, which is a type expression + describing the type of value this input will take. The type expression + is needed in some cases for proper code generation. + + The same insertion point rules apply for this method as `Graph.create_node`. + """ return self.create_node('placeholder', name, type_expr=type_expr) - def get_attr(self, name: str, type_expr: Optional[Any] = None) -> Node: - return self.create_node('get_attr', name, type_expr=type_expr) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a `get_attr` node into the Graph. A `get_attr` `Node` represents the + fetch of an attribute from the `Module` hierarchy. `qualified_name` is the + fully-qualified name of the attribute to be retrieved. For example, if + the traced Module has a submodule named `foo`, which has a submodule named + `bar`, which has an attribute named `baz`, the qualified name `foo.bar.baz` + should be passed as `qualified_name`. + + The same insertion point and type expression rules apply for this method + as `Graph.create_node`. + """ + return self.create_node('get_attr', qualified_name, type_expr=type_expr) def call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> Node: + """ + Insert a `call_module` `Node` into the `Graph`. A `call_module` node + represents a call to the forward() function of a `Module` in the `Module` + hierarchy. For example, if the traced `Module` has a submodule named `foo`, + which has a submodule named `bar`, the qualified name `foo.bar` should + be passed as `module_name` to call that module. + + `args` and `kwargs` represent the args and kwargs passed to the called + `Module`, respectively. + + The same insertion point and type expression rules apply for this method + as `Graph.create_node`. + """ return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) def call_method(self, @@ -241,6 +357,18 @@ def call_method(self, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> Node: + """ + Insert a `call_method` `Node` into the `Graph`. A `call_method` node + represents a call to a given method on the 0th element of `args. + For example, if args[0] is a `Node` representing a `Tensor`, then to call + `relu()` on that `Tensor`, pass `relu` to `method_name`. + + `args` and `kwargs` represent the args and kwargs passed to the called + method, respectively. + + The same insertion point and type expression rules apply for this method + as `Graph.create_node`. + """ return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) def call_function(self, @@ -248,10 +376,22 @@ def call_function(self, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> Node: + """ + Insert a `call_function` `Node` into the `Graph`. A `call_function` node + represents a call to a Python callable, specified by `the_function`. `the_function` + can be any PyTorch operator, Python function, or member of the `builtins` + or `operator` namespaces. + + `args` and `kwargs` represent the args and kwargs passed to the called + method, respectively. + + The same insertion point and type expression rules apply for this method + as `Graph.create_node`. + """ return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node: - """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node + """ Copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node to the graph of self. Example: g : torch.fx.Graph = ... @@ -281,6 +421,14 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb return self.create_node(node.op, node.target, args, kwargs, name, node.type) def output(self, result: Argument, type_expr: Optional[Any] = None): + """ + Insert an `output` `Node` into the `Graph`. An `output` node represents + a `return` statement in the Python code. `result` is the value that should + be returned. + + The same insertion point and type expression rules apply for this method + as `Graph.create_node`. + """ return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) def _name(self, target: Target) -> str: @@ -294,7 +442,7 @@ def _name(self, target: Target) -> str: op = op.replace('.', '_') # delete all characters that are illegal in a Python identifier op = re.sub('[^0-9a-zA-Z_]+', '_', op) - op = snake_case(op) + op = _snake_case(op) if op[0].isdigit(): op = f'_{op}' @@ -318,6 +466,9 @@ def _register_name_used(self, op : str) -> str: return f'{op}_{i}' def python_code(self, root_module: str) -> str: + """ + Turn this `Graph` into valid Python code. + """ free_vars: List[str] = [] modules_used : Set[str] = set() body: List[str] = [] @@ -405,6 +556,10 @@ def forward(self, {', '.join(free_vars)}){maybe_return_annotation}: return fn_code def __str__(self) -> str: + """ + Print a human-readable (not machine-readable) string representation + of this Graph + """ placeholder_names : List[str] = [] # This is a one-element array just so `format_node` can modify the closed # over value diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 4a7ab8cebbc8..3525e180c43f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -172,11 +172,19 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): @property def graph(self): + """ + Return the `Graph` underlying this `GraphModule` + """ return self._graph @graph.setter - def graph(self, val) -> None: - self._graph = val + def graph(self, g) -> None: + """ + Set the underlying `Graph` for this `GraphModule`. This will internally + recompile the `GraphModule` so that the generated `forward()` function + corresponds to `g` + """ + self._graph = g self.recompile() def recompile(self) -> None: @@ -204,6 +212,13 @@ def wrapped_call(self, *args, **kwargs): cls.__call__ = wrapped_call def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying `Graph`. This is because `Graph` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying `Graph` + """ dict_without_graph = self.__dict__.copy() del dict_without_graph['_graph'] return (deserialize_graphmodule, (dict_without_graph,)) diff --git a/torch/fx/node.py b/torch/fx/node.py index 118b32fe15b0..dd304a801155 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -59,10 +59,16 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, @property def next(self) -> 'Node': + """ + Get the next node in the linked list + """ return self._next @property def prev(self) -> 'Node': + """ + Get the previous node in the linked list + """ return self._prev def prepend(self, x: 'Node'): @@ -96,18 +102,38 @@ def _remove_from_list(self): @property def args(self) -> Tuple[Argument, ...]: + """ + Return the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the `fx.Graph` docstring for more + information. + """ return self._args @args.setter def args(self, a : Tuple[Argument, ...]): + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the `fx.Graph` docstring for more + information. + """ self._update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore @property def kwargs(self) -> Dict[str, Argument]: + """ + Return the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the `fx.Graph` docstring for more + information. + """ return self._kwargs @kwargs.setter def kwargs(self, k : Dict[str, Argument]): + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the `fx.Graph` docstring for more + information. + """ self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): @@ -151,7 +177,7 @@ def maybe_replace_node(n : Node) -> Node: def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: - """ apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ + """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ if isinstance(a, tuple): return tuple(map_arg(elem, fn) for elem in a) if isinstance(a, list): diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 20566bb58e6e..a75d5ff908f8 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -38,10 +38,34 @@ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: # instead, let's make python think that args and kwargs are normal variables class Tracer(TracerBase): + """ + `Tracer` is the class that implements the symbolic tracing functionality + of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent + to `Tracer().trace(m)`. + + Tracer can be subclassed to override various behaviors of the tracing + process. The different behaviors that can be overridden are described + in the docstrings of the methods on this class. + """ def __init__(self): super().__init__() def create_arg(self, a: Any) -> Argument: + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the `Graph`. + + By default, the behavior includes: + - Iterate through collection types (e.g. tuple, list, dict) and recursively + call `create_args` on the elements. + - Given a Proxy object, return a reference to the underlying IR `Node` + - Given a non-Proxy Tensor object, emit IR for various cases: + - For a Parameter, emit a `get_attr` node referring to that Parameter + - For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + """ # The base tracer is used to construct Graphs when there is no associated # module hierarchy, so it can never create parameter references. # The default tracer adds the ability to refer to parameters when @@ -95,19 +119,43 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) - def path_of_module(self, mod): + def path_of_module(self, mod) -> str: + """ + Helper method to find the qualified name of `mod` in the Module hierarchy + of `root`. For example, if `root` has a submodule named `foo`, which has + a submodule named `bar`, passing `bar` into this function will return + the string "foo.bar". + """ for n, p in self.root.named_modules(): if mod is p: return n raise NameError('module is not installed as a submodule') def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs): + """ + Method that specifies the behavior of this `Tracer` when it encounters + a call to an `nn.Module` instance. + + By default, the behavior is to check if the called module is a leaf module + via `is_leaf_module`. If it is, emit a `call_module` node referring to + `m` in the `Graph`. Otherwise, call the `Module` normally, tracing through + the operations in its `forward` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + `Module` boundaries. + """ module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): return forward(*args, **kwargs) return self.create_proxy('call_module', module_qualified_name, args, kwargs) def create_args_for_root(self, root_fn, is_module): + """ + Create `placeholder` nodes corresponding to the signature of the `root` + Module. This method introspects `root`'s signature and emits those + nodes accordingly, also supporting *args and **kwargs. + """ # In some cases, a function or method has been decorated with a wrapper # defined via `functools.wraps`. In this case, the outer code object # will likely not contain the actual parameters we care about, so unwrap @@ -149,6 +197,10 @@ def proxy_placeholder(name: str): return root_fn, args def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: + """ + Trace `root` and return the corresponding FX `Graph` representation. `root` + can either be an `nn.Module` instance or a Python callable. + """ if isinstance(root, torch.nn.Module): self.root = root fn = type(root).forward @@ -211,12 +263,11 @@ def forward(*args, **kwargs): return self.graph -# Symbolic tracing API -# -# Given an `nn.Module` or function instance `root`, this function will return a `GraphModule` -# constructed by recording operations seen while tracing through `root`. -# -# Args: -# - root - the `nn.Module` instance to trace def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule: + """ + Symbolic tracing API + + Given an `nn.Module` or function instance `root`, this function will return a `GraphModule` + constructed by recording operations seen while tracing through `root`. + """ return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), Tracer().trace(root)) From dbfee42a7db8f6ac01b106e28acfe2c0f4b5f56b Mon Sep 17 00:00:00 2001 From: James Reed Date: Wed, 11 Nov 2020 10:54:01 -0800 Subject: [PATCH 15/93] [FX] Fix uses not updating when erasing a node (#47720) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47720 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D24875880 Pulled By: jamesr66a fbshipit-source-id: aae9ffd10f8085b599e7923152287c6e6950ff49 --- test/test_fx.py | 13 +++++++++++++ torch/fx/graph.py | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/test/test_fx.py b/test/test_fx.py index dcb104528402..b035b37663d4 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -700,6 +700,19 @@ def test_graph_fns(self): ref = torch.sin(mod.linear(input) + mod.bias) self.assertEqual(r, ref) + def test_remove_uses(self): + g : torch.fx.Graph = Graph() + x : torch.fx.Node = g.placeholder('x') + relu : torch.fx.Node = g.call_function(torch.relu, (x,)) + neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) + g.output(neg) + + neg.replace_all_uses_with(relu) + g.erase_node(neg) + + self.assertTrue(neg not in relu.users) + + def test_construct_root_dict(self): graph : torch.fx.Graph = torch.fx.Graph() a : torch.fx.Node = graph.create_node('placeholder', 'x') diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 45e518410145..65438471f466 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -261,6 +261,15 @@ def erase_node(self, to_erase : Node): to_erase._erased = True # iterators may retain handles to erased nodes self._len -= 1 + # Null out this Node's argument nodes so that the Nodes referred to + # can update their `users` accordingly + new_args = map_arg(to_erase.args, lambda n: None) + assert isinstance(new_args, tuple) + to_erase.args = new_args + new_kwargs = map_arg(to_erase.kwargs, lambda n: None) + assert isinstance(new_kwargs, dict) + to_erase.kwargs = new_kwargs + def inserting_before(self, n: Optional[Node] = None): """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and From a1db5b0f2bcd96f1e884ff63ebcf2b01f543a77a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 11 Nov 2020 10:54:29 -0800 Subject: [PATCH 16/93] Added CUDA support for complex input for torch.inverse #2 (#47595) Summary: `torch.inverse` now works for complex inputs on GPU. Opening a new PR here. The previous PR was merged and reverted due to a bug in tests marked with `slowTest`. Previous PR https://github.com/pytorch/pytorch/pull/45034 Ref. https://github.com/pytorch/pytorch/issues/33152 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47595 Reviewed By: navahgar Differential Revision: D24840955 Pulled By: anjali411 fbshipit-source-id: ec49fffdc4b3cb4ae7507270fa24e127be14f59b --- aten/src/ATen/cuda/CUDABlas.cpp | 82 ++++++++++++ aten/src/ATen/cuda/CUDABlas.h | 8 ++ aten/src/ATen/cuda/CUDASolver.cpp | 98 +++++++++++++++ aten/src/ATen/cuda/CUDASolver.h | 8 ++ .../ATen/native/cuda/BatchLinearAlgebra.cu | 106 +++++++++++++++- .../ATen/native/cuda/BatchLinearAlgebraLib.cu | 7 +- test/test_autograd.py | 4 +- test/test_linalg.py | 119 +++++++++++++++++- test/test_torch.py | 100 +-------------- tools/autograd/derivatives.yaml | 2 +- tools/autograd/gen_variable_type.py | 2 +- torch/_torch_docs.py | 21 +++- .../_internal/common_methods_invocations.py | 103 +++++++-------- 13 files changed, 499 insertions(+), 161 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 26423889caa4..d4b31401f31f 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -586,6 +586,44 @@ void getrfBatched( handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); } +template <> +void getrfBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasZgetrfBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + info_array, + batchsize)); +} + +template <> +void getrfBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasCgetrfBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + info_array, + batchsize)); +} + template <> void getriBatched( int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, double** dC_array) { @@ -602,6 +640,50 @@ void getriBatched( handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize)); } +template <> +void getriBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize, + c10::complex** dC_array) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasZgetriBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dC_array), + n, + info_array, + batchsize)); +} + +template <> +void getriBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize, + c10::complex** dC_array) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasCgetriBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dC_array), + n, + info_array, + batchsize)); +} + #endif // CUDART_VERSION } // namespace blas diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 17236dc435db..c5b4c43a27b1 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -155,6 +155,10 @@ template<> void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); template<> void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double)); +template<> +void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); +template<> +void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); #define CUDABLAS_GETRI_ARGTYPES(Dtype) \ @@ -168,6 +172,10 @@ template<> void getriBatched(CUDABLAS_GETRI_ARGTYPES(float)); template<> void getriBatched(CUDABLAS_GETRI_ARGTYPES(double)); +template<> +void getriBatched>(CUDABLAS_GETRI_ARGTYPES(c10::complex)); +template<> +void getriBatched>(CUDABLAS_GETRI_ARGTYPES(c10::complex)); #endif // CUDART_VERSION diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp index 8830fe732fdc..00329acda4a9 100644 --- a/aten/src/ATen/cuda/CUDASolver.cpp +++ b/aten/src/ATen/cuda/CUDASolver.cpp @@ -33,6 +33,56 @@ void getrf( handle, m, n, dA, ldda, static_cast(dataPtr.get()), ipiv, info)); } +template <> +void getrf>( + cusolverDnHandle_t handle, + int m, + int n, + c10::complex* dA, + int ldda, + int* ipiv, + int* info) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize( + handle, m, n, reinterpret_cast(dA), ldda, &lwork)); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + void* buffer = allocator.allocate(sizeof(cuDoubleComplex) * lwork).get(); + TORCH_CUSOLVER_CHECK(cusolverDnZgetrf( + handle, + m, + n, + reinterpret_cast(dA), + ldda, + static_cast(buffer), + ipiv, + info)); +} + +template <> +void getrf>( + cusolverDnHandle_t handle, + int m, + int n, + c10::complex* dA, + int ldda, + int* ipiv, + int* info) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize( + handle, m, n, reinterpret_cast(dA), ldda, &lwork)); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + void* buffer = allocator.allocate(sizeof(cuComplex) * lwork).get(); + TORCH_CUSOLVER_CHECK(cusolverDnCgetrf( + handle, + m, + n, + reinterpret_cast(dA), + ldda, + static_cast(buffer), + ipiv, + info)); +} + template <> void getrs( cusolverDnHandle_t handle, int n, int nrhs, double* dA, int lda, int* ipiv, double* ret, int ldb, int* info) { @@ -47,6 +97,54 @@ void getrs( handle, CUBLAS_OP_N, n, nrhs, dA, lda, ipiv, ret, ldb, info)); } +template <> +void getrs>( + cusolverDnHandle_t handle, + int n, + int nrhs, + c10::complex* dA, + int lda, + int* ipiv, + c10::complex* ret, + int ldb, + int* info) { + TORCH_CUSOLVER_CHECK(cusolverDnZgetrs( + handle, + CUBLAS_OP_N, + n, + nrhs, + reinterpret_cast(dA), + lda, + ipiv, + reinterpret_cast(ret), + ldb, + info)); +} + +template <> +void getrs>( + cusolverDnHandle_t handle, + int n, + int nrhs, + c10::complex* dA, + int lda, + int* ipiv, + c10::complex* ret, + int ldb, + int* info) { + TORCH_CUSOLVER_CHECK(cusolverDnCgetrs( + handle, + CUBLAS_OP_N, + n, + nrhs, + reinterpret_cast(dA), + lda, + ipiv, + reinterpret_cast(ret), + ldb, + info)); +} + } // namespace solver } // namespace cuda } // namespace at diff --git a/aten/src/ATen/cuda/CUDASolver.h b/aten/src/ATen/cuda/CUDASolver.h index 06609409f177..327c7b824c5e 100644 --- a/aten/src/ATen/cuda/CUDASolver.h +++ b/aten/src/ATen/cuda/CUDASolver.h @@ -19,6 +19,10 @@ template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(float)); template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(double)); +template<> +void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); +template<> +void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); #define CUDASOLVER_GETRS_ARGTYPES(Dtype) \ @@ -32,6 +36,10 @@ template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(float)); template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(double)); +template<> +void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); +template<> +void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); } // namespace solver diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 318185e43e8a..c0bc2d915bd0 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -328,6 +328,18 @@ inline magma_int_t magmaGetriOptimalBlocksize(magma_int_t n) { return magma_get_sgetri_nb(n); } +template <> +inline magma_int_t magmaGetriOptimalBlocksize>( + magma_int_t n) { + return magma_get_zgetri_nb(n); +} + +template <> +inline magma_int_t magmaGetriOptimalBlocksize>( + magma_int_t n) { + return magma_get_cgetri_nb(n); +} + template<> void magmaGetri( magma_int_t n, double* dA, magma_int_t ldda, magma_int_t* ipiv, double* dwork, @@ -346,6 +358,48 @@ void magmaGetri( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGetri>( + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + magma_int_t* ipiv, + c10::complex* dwork, + magma_int_t lwork, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgetri_gpu( + n, + reinterpret_cast(dA), + ldda, + ipiv, + reinterpret_cast(dwork), + lwork, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGetri>( + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + magma_int_t* ipiv, + c10::complex* dwork, + magma_int_t lwork, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgetri_gpu( + n, + reinterpret_cast(dA), + ldda, + ipiv, + reinterpret_cast(dwork), + lwork, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaGetriBatched( magma_int_t n, double** dA_array, magma_int_t ldda, @@ -364,6 +418,54 @@ void magmaGetriBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGetriBatched>( + magma_int_t n, + c10::complex** dA_array, + magma_int_t ldda, + magma_int_t** ipiv_array, + c10::complex** dinvA_array, + magma_int_t lddia, + magma_int_t* info_array, + magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_zgetri_outofplace_batched( + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dinvA_array), + lddia, + info_array, + batchsize, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGetriBatched>( + magma_int_t n, + c10::complex** dA_array, + magma_int_t ldda, + magma_int_t** ipiv_array, + c10::complex** dinvA_array, + magma_int_t lddia, + magma_int_t* info_array, + magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_cgetri_outofplace_batched( + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dinvA_array), + lddia, + info_array, + batchsize, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaCholeskySolve( magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda, @@ -1019,14 +1121,14 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) { if (self.dim() > 2) { std::vector infos(batchCount(self), 0); auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_batched_inverse( self_working_copy, self_inv_working_copy, infos); }); batchCheckErrors(infos, "inverse_cuda"); } else { int64_t info = 0; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_single_inverse(self_inv_working_copy, info); }); singleCheckErrors(info, "inverse_cuda"); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu index e1af0ee55876..b8289aae26f9 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu @@ -106,13 +106,14 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) { if (self.dim() > 2 && batch_size > 1) { Tensor infos = at::zeros({batchCount(self) * 2}, self.options().dtype(kInt)); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ - apply_batched_inverse_lib(self_working_copy, self_inv_working_copy, infos); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + apply_batched_inverse_lib( + self_working_copy, self_inv_working_copy, infos); }); batchCheckErrors(infos, "inverse_cuda", false, 2); } else { Tensor info = at::zeros({2}, self.options().dtype(at::kInt)); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_single_inverse_lib(self_working_copy, self_inv_working_copy, info); }); batchCheckErrors(info, "inverse_cuda", false, 2); diff --git a/test/test_autograd.py b/test/test_autograd.py index 177a9b4c7805..3003d4c06403 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5006,10 +5006,10 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', - 'exp', 'mean'] + separate_complex_tests + 'exp', 'mean', 'inverse'] + separate_complex_tests # this list corresponds to cases that are not currently implemented -skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex'] +skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex', 'inverse_batched_complex'] def add_test( name, diff --git a/test/test_linalg.py b/test/test_linalg.py index cbab1bde6963..d33d3cfe98ce 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -6,10 +6,11 @@ from random import randrange from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) + (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCUDA, - onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + skipCUDAIfNoMagmaAndNoCusolver, onlyOnCPUAndCUDA) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck @@ -1018,6 +1019,120 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3}) + def test_inverse(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + def run_test(matrix, batches, n): + matrix_inverse = torch.inverse(matrix) + + # Compare against NumPy output + # NumPy uses 'gesv' LAPACK routine solving the equation A A_inv = I + # But in PyTorch 'gertf' + 'getri' is used causing element-wise differences + expected = np.linalg.inv(matrix.cpu().numpy()) + self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=1e-4) + + # Additional correctness tests, check matrix*matrix_inverse == identity + identity = torch.eye(n, dtype=dtype, device=device) + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + result_identity_list1 = [] + result_identity_list2 = [] + p = int(np.prod(batches)) # use `p` instead of -1, so that the test works for empty input as well + for m, m_inv in zip(matrix.contiguous().view(p, n, n), matrix_inverse.contiguous().view(p, n, n)): + result_identity_list1.append(torch.matmul(m, m_inv)) + result_identity_list2.append(torch.matmul(m_inv, m)) + result_identity1 = torch.stack(result_identity_list1).view(*batches, n, n) + result_identity2 = torch.stack(result_identity_list2).view(*batches, n, n) + self.assertEqual(identity.expand_as(matrix), result_identity1) + self.assertEqual(identity.expand_as(matrix), result_identity2) + else: + self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix, matrix_inverse)) + self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix_inverse, matrix)) + + # check the out= variant + matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device) + ans = torch.inverse(matrix, out=matrix_inverse_out) + self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0) + self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0) + + # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix + if matrix.ndim > 2: + expected_inv_list = [] + p = int(np.prod(batches)) # use `p` instead of -1, so that the test works for empty input as well + for mat in matrix.contiguous().view(p, n, n): + expected_inv_list.append(torch.inverse(mat)) + expected_inv = torch.stack(expected_inv_list).view(*batches, n, n) + if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]: + # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA + # individual values can be significantly different for fp32, hence rather high rtol is used + # the important thing is that torch.inverse passes above checks with identity + self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2) + else: + self.assertEqual(matrix_inverse, expected_inv) + + for batches, n in itertools.product( + [[], [1], [4], [2, 3]], + [0, 5, 64] + ): + # large batch size and large matrix size will be tested in test_inverse_many_batches (slow test) + if batches and batches[0] == 32 and n == 256: + continue + matrices = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype).to(device) + run_test(matrices, batches, n) + + # test non-contiguous input + run_test(matrices.transpose(-2, -1), batches, n) + if n > 0: + run_test( + random_fullrank_matrix_distinct_singular_value(n * 2, *batches, dtype=dtype).to(device) + .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n), + batches, n + ) + + @slowTest + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3, + torch.float64: 1e-5, torch.complex128: 1e-5}) + def test_inverse_many_batches(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + def test_inverse_many_batches_helper(b, n): + matrices = random_fullrank_matrix_distinct_singular_value(b, n, n, dtype=dtype).to(device) + matrices_inverse = torch.inverse(matrices) + + # Compare against NumPy output + expected = np.linalg.inv(matrices.cpu().numpy()) + self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-4) + + test_inverse_many_batches_helper(5, 256) + test_inverse_many_batches_helper(3, 512) + test_inverse_many_batches_helper(64, 64) + + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @onlyOnCPUAndCUDA # TODO: XLA doesn't raise exception + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_inverse_errors(self, device, dtype): + # inverse expects batches of square matrices as input + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.inverse(torch.randn(2, 3, 4, 3)) + + # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch + def run_test_singular_input(batch_dim, n): + x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) + x[n, -1, -1] = 0 + with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): + torch.inverse(x) + + for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: + run_test_singular_input(*params) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) diff --git a/test/test_torch.py b/test/test_torch.py index fce680b2b7af..143dadf09ffe 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -35,7 +35,7 @@ wrapDeterministicFlagAPITest, make_tensor) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ - skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCUDAIfNotRocm, \ + skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, \ onlyCUDA, onlyCPU, \ dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \ PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic @@ -6185,88 +6185,6 @@ def test_pow(self, device): torch.pow(m1, 1, out=out) self.assertEqual(out, m1) - @skipCUDAIfNoMagmaAndNoCusolver - @skipCPUIfNoLapack - def test_inverse(self, device): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - def test_inverse_helper(matrix, batches, n): - identity = torch.eye(n, dtype=torch.float64, device=device) - - # correctness test, check matrix*matrix_inverse == identity - matrix_inverse = torch.inverse(matrix) - - self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix, matrix_inverse), atol=1e-8, rtol=0) - self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix_inverse, matrix), atol=1e-8, rtol=0) - - # torch.inverse with out and batches - matrix_inverse_out = torch.empty(*batches, n, n, dtype=torch.float64, device=device) - torch.inverse(matrix, out=matrix_inverse_out) - self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0) - - # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix - if matrix.ndim > 2: - expected_inv_list = [] - for mat in matrix.contiguous().view(-1, n, n): - expected_inv_list.append(torch.inverse(mat)) - expected_inv = torch.stack(expected_inv_list).view(*batches, n, n) - self.assertEqual(matrix_inverse, expected_inv) - - for batches, n in product( - [[], [1], [4], [2, 3], [32]], - [5, 256] - ): - # large batch size and large matrix size will be tested in test_inverse_many_batches (slow test) - if batches and batches[0] == 32 and n == 256: - continue - _matrices = random_fullrank_matrix_distinct_singular_value(n, *batches).to(device) - test_inverse_helper(_matrices, batches, n) - test_inverse_helper(_matrices.transpose(-2, -1), batches, n) - test_inverse_helper( - random_fullrank_matrix_distinct_singular_value(n * 2, *batches).to(device) - .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n), - batches, n - ) - - # incorrect input test - with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): - torch.inverse(torch.randn(2, 3, 4, 3)) - - # test for zero-sized tensor - def test_inverse_helper_zero_size(size): - data = torch.zeros(*size, device=device) - out = torch.inverse(data) - self.assertTrue(out.size() == data.size()) - - test_inverse_helper_zero_size([0, 0]) - test_inverse_helper_zero_size([3, 0, 0]) - test_inverse_helper_zero_size([0, 3, 3]) - - # non-contiguous inputs - if not TEST_NUMPY: - return - - from numpy.linalg import inv - matrices = random_fullrank_matrix_distinct_singular_value(3, 2).to(device).permute(0, 2, 1) - assert not matrices.is_contiguous() - matrices_inverse = torch.inverse(matrices) - expected_inv = torch.as_tensor(inv(matrices.cpu().numpy())) - self.assertEqual(matrices_inverse, expected_inv.to(device)) - - @skipCUDAIfNoMagmaAndNoCusolver - @skipCPUIfNoLapack - @onlyOnCPUAndCUDA # TODO: XLA doesn't raise exception - def test_inverse_singular(self, device): - def helper(batch_dim, n): - x = torch.eye(3, 3, dtype=torch.float, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) - x[n, -1, -1] = 0 - - with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): - torch.inverse(x) - - for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: - helper(*params) - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @onlyOnCPUAndCUDA @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) @@ -7011,22 +6929,6 @@ def test_is_set_to(self, device): self.assertFalse(t1.is_set_to(t2)) self.assertFalse(t2.is_set_to(t1)) - @slowTest - @skipCUDAIfNoMagmaAndNoCusolver - @skipCPUIfNoLapack - def test_inverse_many_batches(self, device): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - def test_inverse_many_batches_helper(b, n): - matrices = random_fullrank_matrix_distinct_singular_value(b, n, n).to(device) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(torch.matmul(matrices_inverse, matrices), - torch.eye(b, dtype=torch.float64, device=device).expand_as(matrices)) - - test_inverse_many_batches_helper(5, 256) - test_inverse_many_batches_helper(3, 512) - test_inverse_many_batches_helper(64, 64) - @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3}) @skipCUDAIfNoMagma @skipCPUIfNoLapack diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 49633ca17733..2ffd28400481 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -587,7 +587,7 @@ index: non_differentiable - name: inverse(Tensor self) -> Tensor - self: -at::matmul(result.transpose(-2, -1), at::matmul(grad, result.transpose(-2, -1))) + self: -at::matmul(result.conj().transpose(-2, -1), at::matmul(grad, result.conj().transpose(-2, -1))) - name: isnan(Tensor self) -> Tensor self: non_differentiable diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c305be0bae0f..b9215a66b098 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -94,7 +94,7 @@ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', - 'exp', 'nonzero', 'mean' + 'exp', 'nonzero', 'mean', 'inverse' } # Some operators invalidate the grad_accumulator. Let's reset it. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index a8b45284187c..fe313a329a81 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3500,6 +3500,8 @@ def merge_dicts(*dicts): of 2D square tensors, in which case this function would return a tensor composed of individual inverses. +Supports real and complex input. + .. note:: Irrespective of the original strides, the returned tensors will be @@ -3512,7 +3514,7 @@ def merge_dicts(*dicts): Keyword args: {out} -Example:: +Examples:: >>> x = torch.rand(4, 4) >>> y = torch.inverse(x) @@ -3524,12 +3526,29 @@ def merge_dicts(*dicts): [ 0.0000, -0.0000, -0.0000, 1.0000]]) >>> torch.max(torch.abs(z - torch.eye(4))) # Max non-zero tensor(1.1921e-07) + >>> # Batched inverse example >>> x = torch.randn(2, 3, 4, 4) >>> y = torch.inverse(x) >>> z = torch.matmul(x, y) >>> torch.max(torch.abs(z - torch.eye(4).expand_as(x))) # Max non-zero tensor(1.9073e-06) + + >>> x = torch.rand(4, 4, dtype=torch.cdouble) + >>> y = torch.inverse(x) + >>> z = torch.mm(x, y) + >>> z + tensor([[ 1.0000e+00+0.0000e+00j, -1.3878e-16+3.4694e-16j, + 5.5511e-17-1.1102e-16j, 0.0000e+00-1.6653e-16j], + [ 5.5511e-16-1.6653e-16j, 1.0000e+00+6.9389e-17j, + 2.2204e-16-1.1102e-16j, -2.2204e-16+1.1102e-16j], + [ 3.8858e-16-1.2490e-16j, 2.7756e-17+3.4694e-17j, + 1.0000e+00+0.0000e+00j, -4.4409e-16+5.5511e-17j], + [ 4.4409e-16+5.5511e-16j, -3.8858e-16+1.8041e-16j, + 2.2204e-16+0.0000e+00j, 1.0000e+00-3.4694e-16j]], + dtype=torch.complex128) + >>> torch.max(torch.abs(z - torch.eye(4, dtype=torch.cdouble))) # Max non-zero + tensor(7.5107e-16, dtype=torch.float64) """.format(**common_args)) add_docstr(torch.isinf, r""" diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index c409b5265a67..3426390256af 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1061,11 +1061,11 @@ def method_tests(): ('matrix_power', (S, S, S), [3], "n=3"), ('matrix_power', (S, S, S), [1], "n=1"), ('matrix_power', (S, S, S), [0], "n=0"), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('matrix_exp', (S, S), NO_ARGS, "single_matrix"), ('matrix_exp', (S, S, S), NO_ARGS, "batch_of_matrices"), @@ -1207,104 +1207,107 @@ def method_tests(): ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', (), [0]), ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', (), [0]), ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', (), [0]), - ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), + ('inverse', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3), + ('inverse', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 2, 3, dtype=dtype).to(device), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (S, S), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (1, 1), NO_ARGS, '1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_psd_matrix(S), + ('det', lambda dtype, device: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_pd_matrix(S), + ('det', lambda dtype, device: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, S - 2), + ('det', lambda dtype, device: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, + ('det', lambda dtype, device: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (3, 3, S, S), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (3, 3, 1, 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_matrix(S, 3), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_psd_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_psd_matrix(S, 3), NO_ARGS, 'batched_symmetric_psd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_pd_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_pd_matrix(S, 3), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3), NO_ARGS, + ('det', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), # For `logdet` and `slogdet`, the function at det=0 is not smooth. # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use # `make_nonzero_det` to make the random matrices have nonzero det. For # `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the # matrix have positive det. - ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS, 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS, 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, 1, 1), 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S, 3), 1), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S, 3), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_pd_matrix(S, 3), 1), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S, 3), 1, 0), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S, 3), 1, 0), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1_pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS, '1x1_neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, 'pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS, 'neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS, 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS, + ('slogdet', lambda dtype, device: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, + ('slogdet', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), -1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, 1, 1), -1), NO_ARGS, 'batched_1x1_neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, 'batched_pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3)), NO_ARGS, + ('slogdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S, 3)), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_symmetric_pd_matrix(S, 3), NO_ARGS, + ('slogdet', lambda dtype, device: random_symmetric_pd_matrix(S, 3), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, + ('slogdet', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS, 'wide', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS, 'tall', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,), + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,), 'wide_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,), + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,), 'tall_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, 'large', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :(S - 2), :], NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :(S - 2), :], NO_ARGS, 'wide_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :, :(S - 2)], NO_ARGS, + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :, :(S - 2)], NO_ARGS, 'tall_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :(S - 2), :], (False,), + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :(S - 2), :], (False,), 'wide_all_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][..., :, :(S - 2)])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :, :(S - 2)], (False,), + ('svd', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :, :(S - 2)], (False,), 'tall_all_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0][..., :, :(S - 2)], usv[1], usv[2])), ('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), @@ -1525,7 +1528,7 @@ def maybe_non_contig(tensor): v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex()) return v elif callable(arg): - return map_arg(arg()) + return map_arg(arg(dtype=dtype, device=device)) else: return arg args_out = tuple(map_arg(arg) for arg in call_args) From 513f62b45ba18edff9cf034d09bd12803485db40 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Wed, 11 Nov 2020 11:45:34 -0800 Subject: [PATCH 17/93] [hotfix] fix collect_env not working when torch compile/install fails (#47752) Summary: fix collect env not working when pytorch compile from source failed mid-way. ``` Traceback (most recent call last): OSError: /home/rongr/local/pytorch/torch/lib/libtorch_global_deps.so: cannot open shared object file: No such file or directory ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/47752 Reviewed By: janeyx99 Differential Revision: D24888576 Pulled By: walterddr fbshipit-source-id: 3b20daeddbb4118491fb0cca9fb59d861f683da7 --- torch/utils/collect_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 467cc0640ff0..aabbfc843cd0 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -10,7 +10,7 @@ try: import torch TORCH_AVAILABLE = True -except (ImportError, NameError, AttributeError): +except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information From 4078f4466803b225de19bc4bac4ea479397f2814 Mon Sep 17 00:00:00 2001 From: Peiyao Zhou Date: Wed, 11 Nov 2020 12:01:22 -0800 Subject: [PATCH 18/93] [TB][embedding supporting] Modify histogram to accept multipy types to skip Castop and avoid OOMing in Castop Summary: To support min/max/mean/std, SummarizeOp need to skip size checking (similar to the LpNorm error mentioned above) and accept multiple types Test Plan: unit test: `buck test //caffe2/caffe2/fb/tensorboard/tests:tensorboard_accumulate_histogram_op_test` https://our.intern.facebook.com/intern/testinfra/testrun/1407375057859572 `buck test //caffe2/caffe2/fb/tensorboard/tests:tensorboard_accumulate_histogram_op_test --stress-runs 1000` https://our.intern.facebook.com/intern/testinfra/testrun/2533274832166362 Reviewed By: cryptopic Differential Revision: D24605507 fbshipit-source-id: fa08372d7c9970083c38abd432d4c86e84fb10e0 --- caffe2/python/hypothesis_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index 9298134f651c..9e9c68502a09 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -2715,7 +2715,7 @@ def histogram(X): Y[X >= upper_bound] = num_buckets + 1 Y[(X >= lower_bound) & (X < upper_bound)] = \ ((X[(X >= lower_bound) & (X < upper_bound)] - lower_bound) / - segment + 1).astype(np.int32) + segment + 1).astype(np.int32) for i in range(Y.shape[0]): for j in range(Y.shape[1]): From da2e2336b673d63a49741901538ec8b9e25c32a6 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 11 Nov 2020 12:04:49 -0800 Subject: [PATCH 19/93] [ONNX] Export and shape inference for prim uninitialized in If subblock (#46094) Summary: Enable export of prim::Uninitialized in If subblock outputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46094 Reviewed By: houseroad Differential Revision: D24838537 Pulled By: bzinodev fbshipit-source-id: d0719b140393595e6df114ef5cc1bb845e919c14 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 36 ++++- .../passes/onnx/fixup_onnx_controlflow.cpp | 124 ++++++++++++++++-- torch/onnx/utils.py | 3 +- 3 files changed, 150 insertions(+), 13 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 64bf20742d15..609c2e75f330 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -4119,7 +4119,39 @@ def run(): self.assertEqual('Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. ' + 'Please try opset version 11.', the_exception.args[0]) - @disableScriptTest() # export prim::Uninitialized + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfONNXShapeInference(False) + def test_uninitialized(self): + class UninitializedModel(torch.nn.Module): + def forward(self, y): + if y.shape[1] < 5: + if y.size(0) == 1: + y = y + 4 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(UninitializedModel(), x) + + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfONNXShapeInference(False) + def test_uninitialized_dynamic(self): + class UninitializedModel(torch.nn.Module): + def forward(self, y): + if y.shape[1] < 5: + if y.size(0) == 1: + y = y + 4 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + y = torch.ones((6, 7), dtype=torch.int) + self.run_test(UninitializedModel(), x, test_with_inputs=[y], + input_names=['input_1'], + dynamic_axes={'input_1': [0, 1]}) + def test_reflection_pad(self): model = torch.nn.ReflectionPad1d(2) x = torch.randn(2, 4, 4) @@ -4129,7 +4161,6 @@ def test_reflection_pad(self): x = torch.randn(2, 2, 4, 4) self.run_test(model, x) - @disableScriptTest() # export prim::Uninitialized def test_replication_pad(self): model = torch.nn.ReplicationPad1d(2) x = torch.randn(2, 4, 4) @@ -4140,7 +4171,6 @@ def test_replication_pad(self): self.run_test(model, x) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() # export prim::Uninitialized def test_im2col(self): class Unfold(torch.nn.Module): def forward(self, input): diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index c6c8eceb8f45..19523f51f2e1 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -226,6 +227,111 @@ std::vector FixupONNXLoopNode(Node* node, int opset_version) { return new_outputs; } +// Check if node is prim::Uninitialized, +// or output of prim::Uninitialized->onnx::Identity +bool IsUninitializedNode(Node* n) { + if (n->kind() == ::c10::onnx::Identity && + n->inputs()[0]->node()->kind() == prim::Uninitialized) + return true; + if (n->kind() == prim::Uninitialized) + return true; + return false; +} + +// Infer shape and type of the uninitialized_output from the corresponding +// output of the other subblock. prim::Uninitialized node is proven to be +// unused. So replace this node with a constant of the inferred shape and type. +void InferShapeTypeForUninitializedOutput( + Graph* graph, + Block* block, + Value* uninitialized_output, + Value* other_output) { + auto output_type = other_output->type()->expect(); + auto elem_type = at::initialTensorOptions().dtype(output_type->scalarType()); + Node* const_node = graph->create(::c10::onnx::Constant, 1); + + if (output_type->sizes().concrete_sizes().has_value()) { + auto size = output_type->sizes().concrete_sizes().value(); + const_node->t_(attr::value, at::zeros(size, elem_type)); + const_node->output()->setType(other_output->type()); + const_node->output()->copyMetadata(other_output); + } else { + const_node->t_(attr::value, at::zeros({}, elem_type)); + const_node->output()->setType( + TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {})); + } + const_node->insertBefore(block->return_node()); + uninitialized_output->replaceAllUsesWith(const_node->output()); + uninitialized_output->node()->destroy(); +} + +// Corresponding outputs for ONNX If then and else subblocks should have +// same shape and type. This pass detects if prim::Uninitialized node +// appears as part of outputs of either of the subblocks, and infers +// shape and type from the corresponding output of the other subblock +// In the example graph below, shape and type of the subblock output %7 +// for subblock 1 is inferred from %y.1. Shape and type of Subblock +// output %7 is inferred from %y.5. +// +// graph(%y.1 : Int(3:4, 4:1, requires_grad=0, device=cpu)): +// ... +// %7 : Tensor = prim::Uninitialized() +// %16 : bool, %17 : Tensor, %y.14 : Tensor = prim::If(%15) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:614:20 +// block0(): +// %y.5 : Tensor = aten::add(%y.1, %3, %6) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:615:28 +// -> (%2, %7, %y.5) +// block1(): +// -> (%1, %y.1, %7) +// ... + +void ONNXFixupUninitializedOutput(Node* node) { + if (node->kind() != ::c10::onnx::If) { + return; + } + + GRAPH_DUMP("Graph before fixing If shape type: ", node->owningGraph()); + auto* if_node = node; + auto* graph = if_node->owningGraph(); + + // Check if the input to ONNX If node is node Bool, and insert + // cast to Bool if needed. + if (!if_node->input()->type()->isSubtypeOf(BoolType::get())) { + Node* cast_node = CreateCastToBoolNode(if_node->input(), graph); + cast_node->insertBefore(if_node); + if_node->replaceInputWith(if_node->input(), cast_node->output()); + } + + Block* then_block = if_node->blocks()[0]; + Block* else_block = if_node->blocks()[1]; + + // Infer shape and type for subblock outputs + TORCH_INTERNAL_ASSERT( + then_block->outputs().size() == else_block->outputs().size()) + for (size_t i = 0; i < else_block->outputs().size(); i++) { + Value* then_block_output = then_block->outputs()[i]; + Value* else_block_output = else_block->outputs()[i]; + + // If both subblocks have an uninitialized output, shape and type cannot + // be inferred. + TORCH_CHECK( + !(IsUninitializedNode(then_block_output->node()) && + IsUninitializedNode(else_block_output->node())), + "Cannot infer shape and type for ONNX If with uninitialized output in both subblocks. Please check the model graph."); + + if (IsUninitializedNode(then_block_output->node())) { + InferShapeTypeForUninitializedOutput( + graph, then_block, then_block_output, else_block_output); + if_node->outputs()[i]->setType(then_block->outputs()[i]->type()); + } else if (IsUninitializedNode(else_block_output->node())) { + InferShapeTypeForUninitializedOutput( + graph, else_block, else_block_output, then_block_output); + if_node->outputs()[i]->setType(else_block->outputs()[i]->type()); + } + } +} + std::vector FixupONNXIfNode(Node* node, int opset_version) { if (node->kind() != ::c10::onnx::If) { return node->outputs().vec(); @@ -234,17 +340,17 @@ std::vector FixupONNXIfNode(Node* node, int opset_version) { auto* if_node = node; auto* graph = if_node->owningGraph(); for (Block* block : node->blocks()) { - if (block->nodes().begin() == block->nodes().end()) { - // ONNX does not support empty blocks, must use some op which does - // nothing - Value* output = block->outputs()[0]; - Node* id_node = graph->create(onnx::Identity); - id_node->insertBefore(block->return_node()); - id_node->addInput(output); - id_node->output()->copyMetadata(output); - block->return_node()->replaceInputWith(output, id_node->output()); + for (Value* output : block->outputs()) { + if (output->node()->owningBlock() != block) { + Node* id_node = graph->create(onnx::Identity); + id_node->insertBefore(block->return_node()); + id_node->addInput(output); + id_node->output()->copyMetadata(output); + block->return_node()->replaceInputWith(output, id_node->output()); + } } } + ONNXFixupUninitializedOutput(if_node); GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph()); return if_node->outputs().vec(); } diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 42580bb3f0f2..677015902d30 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -949,10 +949,11 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor else: raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format( n.kindOf("value"))) - elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack": + elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack" or op_name == "Uninitialized": # None is not an ONNX operator; keep it as None # Let the exporter handle and finally eliminate these ops # ListConstruct and ListUnpack will be erased in the ONNX peephole pass + # Uninitialized will be erased during shape/type inference return None elif op_name == "device" and n.output().type().kind() == "DeviceObjType": return None From 1abe6e5ad4384810160389200820d06e57f14944 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 11 Nov 2020 12:43:00 -0800 Subject: [PATCH 20/93] [ONNX] Bool inputs to index_put updated symbolic (#46866) Summary: Cases with bool inputs to index_put nodes were handled for tracing purposes. This PR adds support for similar situations in scripting Pull Request resolved: https://github.com/pytorch/pytorch/pull/46866 Reviewed By: malfet Differential Revision: D24870818 Pulled By: bzinodev fbshipit-source-id: 2d75ca6f5f4b79d8c5ace337633c5aed3bdc4be7 --- .../jit/passes/onnx/preprocess_for_onnx.cpp | 80 +-------------- torch/onnx/symbolic_opset11.py | 99 +++++++++++++++++++ 2 files changed, 100 insertions(+), 79 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index d82067d602fd..7a0014eb030d 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -72,7 +72,7 @@ void FuseWithListUnpack(Node* n) { Symbol::fromQualString("attr::_outputs"), static_cast(listUnpack_node->outputs().size())); - for (auto i = 0; i < listUnpack_node->outputs().size(); ++i) { + for (size_t i = 0; i < listUnpack_node->outputs().size(); ++i) { auto new_output = n->addOutput(); new_output->copyMetadata(listUnpack_node->output(i)); } @@ -159,83 +159,6 @@ static void ReplaceAddWithConcat(Block* b) { } } -// Replace aten::index_put_ with aten::masked_scatter or aten::masked_fill -// when inputs to the index_put node contains boolean inputs -// -// before the pass (index_put -> masked_fill): -// graph(%0 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu)): -// %mask.1 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu) -// %22 : Tensor?[] = prim::ListConstruct(%21) -// %23 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() -// %24 : bool = prim::Constant[value=0]() -// %mask : Float(2:4, 2:2, 2:1) = aten::index_put_(%mask.1, %22, %23, %24) -// -// after the pass -// graph(%0 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu)): -// %46 : Float(requires_grad=0, device=cpu) = prim::Constant[value={5}]() -// %mask.1 : Float(2:4, 2:2, 2:1, requires_grad=0, device=cpu) = -// %23 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() -// %24 : bool = prim::Constant[value=0]() -// %49 : Tensor = aten::masked_fill(%mask.1, %21, %23) -// -// before the pass (index_put -> masked_scatter) -// %48 : Float(8:1, requires_grad=0, device=cpu) = prim::Constant[value= 1 1 -// 1 1 1 1 1 1 [ CPUFloatType{8} ]]() -// %42 : Tensor?[] = prim::ListConstruct(%41) -// %43 : bool = prim::Constant[value=0]() -// %44 : Float(2:4, 2:2, 2:1) = aten::index_put_(%mask, %42, %48, %43) -// return (%44) -// -// after the pass: -// %48 : Float(8:1, requires_grad=0, device=cpu) = prim::Constant[value= 1 1 -// 1 1 1 1 1 1 [ CPUFloatType{8} ]]() -// %49 : Tensor = aten::masked_fill(%mask.1, %21, %23) -// %41 : Bool(2:4, 2:2, 2:1) = aten::to() -// %50 : Tensor = aten::masked_scatter(%49, %41, %48) -// return (%50) -static void ReplaceIndexPutWithMaskedScatter(Block* b) { - for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { - for (auto* child_block : it->blocks()) { - ReplaceIndexPutWithMaskedScatter(child_block); - } - if (it->kind() == aten::index_put_) { - auto* lc_node = it->input(1)->node(); - TensorTypePtr mask_tensor = lc_node->input(0)->type()->cast(); - if (!(mask_tensor) || !(mask_tensor->scalarType().has_value()) || - (mask_tensor->scalarType().value()) != c10::ScalarType::Bool) { - continue; - } - - if ((!lc_node->inputs().size()) == 1) { - continue; - } - - // If equated value is just a single scalar, then convert to masked_fill, - // and if value is a tensor of appropriate size, we convert to - // masked_scatter. - Node* masked_node; - TensorTypePtr val_tensor = it->input(2)->type()->cast(); - if ((val_tensor) && (val_tensor->sizes().size().has_value())) { - if ((val_tensor->sizes().size().value()) == 0) { - masked_node = b->owningGraph()->create(aten::masked_fill, 1); - } else { - masked_node = b->owningGraph()->create(aten::masked_scatter, 1); - } - } else { - continue; - } - - masked_node->insertBefore(*it); - masked_node->addInput(it->input(0)); - masked_node->addInput(lc_node->input(0)); - masked_node->addInput(it->input(2)); - it->replaceAllUsesWith(masked_node); - it->removeAllInputs(); - it.destroyCurrent(); - } - } -} - // This pass also covers the case when the input to ListUnpack // is int[] comming from some other op than ListConstruct (like Slice or Shape) // @@ -296,7 +219,6 @@ static void fuseListAndListUnpack(Block* b) { void PreprocessForONNX(std::shared_ptr& graph) { FuseWithListUnpack(graph->block()); ReplaceAddWithConcat(graph->block()); - ReplaceIndexPutWithMaskedScatter(graph->block()); fuseListAndListUnpack(graph->block()); } diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index f6acc4120dc2..8c3f740d8b70 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -80,6 +80,105 @@ def index_put(g, self, indices_list_value, values, accumulate=False): ] index = g.op("Concat", *indices_list, axis_i=-1) else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains boolean inputs + # + # index_put -> masked_fill + # + # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %6 : None = prim::Constant() + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const) + # %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]() + # %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %11 : Device = prim::Constant[value="cpu"]() + # %12 : None = prim::Constant() + # %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %15 : None = prim::Constant() + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %22 : int[] = prim::Constant[value=[-1]]() + # %23 : Tensor = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %3 : Tensor = onnx::Equal(%0, %some_const) + # %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3) + # %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4) + # %19 : Tensor = onnx::Cast[to=9](%12) + # %20 : Tensor = onnx::Constant[value={1}]() + # %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = onnx::Where(%19, %20, %0) + # return (%21) + # + # index_put -> masked_scatter + # + # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %6 : None = prim::Constant() + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]() + # %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %18 : Device = prim::Constant[value="cpu"]() + # %19 : None = prim::Constant() + # %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %22 : None = prim::Constant() + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Tensor = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + # + # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %3 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %4 : Tensor = onnx::Equal(%0, %some_const) + # %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4) + # %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5) + # %19 : Tensor = onnx::Shape(%0) + # %20 : Tensor = onnx::Expand(%13, %19) + # %21 : Tensor = onnx::NonZero(%20) + # %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21) + # %23 : Tensor = onnx::Constant[value={-1}]() + # %24 : Tensor = onnx::Reshape(%3, %23) + # %25 : Tensor = onnx::Shape(%22) + # %27 : Tensor = onnx::Constant[value={0}]() + # %28 : Tensor = onnx::Gather[axis=0](%25, %27) + # %29 : Tensor = onnx::Constant[value={0}]() + # %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29) + # %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28) + # %32 : Tensor = onnx::Constant[value={0}]() + # %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32) + # %34 : Tensor = onnx::Slice(%24, %30, %31, %33) + # %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = onnx::ScatterND(%0, %22, %34) + # return (%35) + + bool_inp = list(index.node().inputs())[0] + if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool': + if values.type() is not None: + if values.type().dim() == 0: + from torch.onnx.symbolic_opset9 import masked_fill + return masked_fill(g, self, bool_inp, values) + return masked_scatter(g, self, bool_inp, values) broadcast_index_shape = g.op("Shape", index) index = g.op("Unsqueeze", index, axes_i=[-1]) sub_data_shape = sym_help._slice_helper( From 6c815c71b30e36a55b6b8ae2e78e3ff612c9c98f Mon Sep 17 00:00:00 2001 From: Shen Li Date: Wed, 11 Nov 2020 13:01:55 -0800 Subject: [PATCH 21/93] Revert to use NCCL 2.7.8-1 (#47638) Summary: Only depend on stable NCCL releases Pull Request resolved: https://github.com/pytorch/pytorch/pull/47638 Reviewed By: mingzhe09088 Differential Revision: D24847765 Pulled By: mrshenli fbshipit-source-id: 2c5f29602aa7403c110797cb07f8fb6151a1b60d --- third_party/nccl/nccl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 31b5bb6f6447..033d799524fb 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 31b5bb6f6447da98b9110c605465f9c09621074e +Subproject commit 033d799524fb97629af5ac2f609de367472b2696 From fc24d0656afcda4d57538f129a03bdf62506021c Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 11 Nov 2020 14:01:45 -0800 Subject: [PATCH 22/93] Tensor.contiguous, Tensor.is_contiguous batch rule (#47621) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47621 Followup to #47365. is_contiguous on BatchedTensorImpl is implemented as: - Whenever one creates a BatchedTensorImpl, we cache the strides of the per-examples, just like how we cache the sizes of the per-examples. - With the cached strides, we use TensorImpl::refresh_contiguous() to compute if the tensor is contiguous or not. - is_contiguous checks the `is_contiguous_` flag that refresh_contiguous() populates. Both contiguous and is_contiguous only support torch.contiguous_format. I'm not sure what the semantics should be for other memory formats; they are also rank dependent (e.g., channels_last tensor must have 4 dimensions) which makes this a bit tricky. Test Plan: - new tests Reviewed By: Chillee, anjali411 Differential Revision: D24840975 Pulled By: zou3519 fbshipit-source-id: 4d86dbf11e2eec45f3f08300ae3f2d79615bb99d --- aten/src/ATen/BatchedTensorImpl.cpp | 12 +++- aten/src/ATen/BatchingRegistrations.cpp | 11 ++++ aten/src/ATen/test/vmap_test.cpp | 2 +- test/test_vmap.py | 76 +++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/BatchedTensorImpl.cpp b/aten/src/ATen/BatchedTensorImpl.cpp index 3c2ce5b9a671..8f373b1ea29b 100644 --- a/aten/src/ATen/BatchedTensorImpl.cpp +++ b/aten/src/ATen/BatchedTensorImpl.cpp @@ -19,13 +19,18 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) const auto public_dims = value_.dim() - bdims_.size(); const auto value_sizes = value_.sizes(); + const auto value_strides = value_.strides(); sizes_.clear(); sizes_.reserve(public_dims); + strides_.clear(); + strides_.reserve(public_dims); for (int64_t dim = 0; dim < public_dims; dim++) { auto actual_dim = actualDim(dim, /*wrap_dim=*/false); sizes_.push_back(value_sizes.at(actual_dim)); + strides_.push_back(value_strides.at(actual_dim)); } refresh_numel(); + refresh_contiguous(); } int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const { @@ -77,9 +82,14 @@ IntArrayRef BatchedTensorImpl::strides() const { int64_t BatchedTensorImpl::stride(int64_t d) const { TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap"); } + bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const { - TORCH_CHECK(false, "NYI: querying is_contiguous inside of vmap"); + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: querying is_contiguous inside of vmap for memory_format ", + "other than torch.contiguous_format"); + return is_contiguous_; } + const Storage& BatchedTensorImpl::storage() const { TORCH_CHECK(false, "Due to limitations, we cannot access the storage() of a tensor from inside of vmap."); } diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index bfd06bfe5f75..3c025e10af99 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -338,6 +338,15 @@ Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64 return self_physical.newLogicalFromPhysical(result); } +Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) { + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ", + "than torch.contiguous_format"); + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = physical_view.tensor().contiguous(memory_format); + return physical_view.newLogicalFromPhysical(result); +} + Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto size_physical = self_physical.getPhysicalShape(size); @@ -1050,6 +1059,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl_UNBOXED("new_empty", new_empty_batching_rule); m.impl_UNBOXED("new_empty_strided", new_empty_strided_batching_rule); m.impl("new_zeros", new_zeros_batching_rule); + + m.impl("contiguous", contiguous_batching_rule); } } // namespace at diff --git a/aten/src/ATen/test/vmap_test.cpp b/aten/src/ATen/test/vmap_test.cpp index 32aadc25b383..99845b5df0ae 100644 --- a/aten/src/ATen/test/vmap_test.cpp +++ b/aten/src/ATen/test/vmap_test.cpp @@ -15,8 +15,8 @@ TEST(VmapTest, TestBatchedTensor) { ASSERT_EQ(x.sizes(), expected_size); ASSERT_EQ(x.dim(), 2); ASSERT_EQ(x.numel(), 8); + ASSERT_EQ(x.is_contiguous(), false); ASSERT_THROW(x.strides(), c10::Error); - ASSERT_THROW(x.is_contiguous(), c10::Error); ASSERT_THROW(x.storage(), c10::Error); ASSERT_THROW(x.storage_offset(), c10::Error); } diff --git a/test/test_vmap.py b/test/test_vmap.py index 2e4e1b138123..8a40860645a6 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -1265,6 +1265,26 @@ def get(shape): result = vmap(op)(real_tensor) self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) + def test_contiguous(self): + op = Tensor.contiguous + + self._test_unary(op, TensorFactory.randn, 'cpu') + + # check that contiguous returns the original tensor if the per-examples + # are already contiguous + B0 = 3 + x = torch.randn(B0, 2, 5, 7) + x = x.movedim(0, 2) + result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) + self.assertTrue(result is x) + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) + def test_chunk(self): test = self._vmap_view_test op = torch.chunk @@ -1432,6 +1452,62 @@ def foo(x): self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) + def test_is_contiguous(self): + def foo(x): + if x.is_contiguous(): + return torch.tensor(1.) + else: + return torch.tensor(0.) + + B0, B1 = 3, 5 + + # Single batch dim + contig = torch.randn(B0, 2, 7) + self.assertEqual(vmap(foo)(contig), torch.ones(B0)) + + noncontig = torch.randn(2, B0, 7) + self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, B0, 7).movedim(1, 0) + self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, 7, B0) + self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) + + # Multiple batch dims + contig = torch.randn(B0, B1, 3) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3) + self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3).movedim(0, 1) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + noncontig = torch.randn(B0, 3, B1) + self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) + + # is_contiguous on empty tensor is True + def bar(x): + assert x.is_contiguous() + return x + + vmap(bar)(torch.randn(B0, 0, 3)) + vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) + vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2)) + + # is_contiguous with other memory formats + def baz(x, memory_format): + x.is_contiguous(memory_format=memory_format) + return x + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 2, 7, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) + def test_movedim(self): op = torch.movedim test = self._vmap_view_test From f6ff6478cf7b963b650a3621a78e42cfc8265b00 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 11 Nov 2020 14:01:45 -0800 Subject: [PATCH 23/93] Make kwargs argument optional in _batched_grad_test (#47625) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47625 kwargs is {} most of the time so this PR makes it optional. Note that it is bad practice for {} to be a default argument; we work around this by using None as the default and handling it accordingly. Test Plan - `pytest test/test_vmap.py -v` Test Plan: Imported from OSS Reviewed By: Chillee Differential Revision: D24842571 Pulled By: zou3519 fbshipit-source-id: a46b0c6d5240addbe3b231b8268cdc67708fa9e0 --- test/test_vmap.py | 70 +++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/test/test_vmap.py b/test/test_vmap.py index 8a40860645a6..88c95872414a 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -2030,7 +2030,9 @@ def _vmap_test(self, *args, **kwargs): # output_process_fn: a function that maps the outputs to the part # that should be differentiated. # batch_size: the batch dim size for the batched grad - def _batched_grad_test(self, op, args, kwargs, output_process_fn=lambda x: x, batch_size=3): + def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} outputs = op(*args, **kwargs) outputs = differentiable(output_process_fn(outputs)) batched_vectors = tuple(construct_v(out, batch_size) for out in outputs) @@ -2054,7 +2056,9 @@ def vector_jacobian_product(*vectors): # Regression. # It might be useful to have a test that computes batched first gradients and # then uses those to compute batched second gradients in the future. - def _batched_grad_grad_test(self, op, args, kwargs, output_process_fn=lambda x: x, batch_size=3): + def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} outputs = op(*args, **kwargs) outputs = differentiable(output_process_fn(outputs)) ones = tuple(torch.ones_like(out) for out in outputs) @@ -2081,12 +2085,12 @@ def _test_arithmetic(self, op, device, test_grad_grad=True): x = torch.randn(2, 3, requires_grad=True, device=device) y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) scalar = 3.14 - self._batched_grad_test(op, (x, y), {}) - self._batched_grad_test(op, (scalar, y), {}) - self._batched_grad_test(op, (x, scalar), {}) + self._batched_grad_test(op, (x, y)) + self._batched_grad_test(op, (scalar, y)) + self._batched_grad_test(op, (x, scalar)) if test_grad_grad: - self._batched_grad_grad_test(op, (x, y), {}) + self._batched_grad_grad_test(op, (x, y)) def test_add(self, device): self._test_arithmetic(torch.add, device, test_grad_grad=False) @@ -2109,7 +2113,7 @@ def test_expand(self, device): def op(x): return x.expand(5, 5, 2, 3) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) @allowVmapFallbackUsage def test_index(self, device): @@ -2120,18 +2124,18 @@ def op(x): y = x * x return y[index] - self._batched_grad_test(op, (x,), {}) - self._batched_grad_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) def test_lgamma(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(Tensor.lgamma, (x,), {}) - self._batched_grad_grad_test(Tensor.lgamma, (x,), {}) + self._batched_grad_test(Tensor.lgamma, (x,)) + self._batched_grad_grad_test(Tensor.lgamma, (x,)) def test_log(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) - self._batched_grad_test(torch.log, (x,), {}) - self._batched_grad_grad_test(torch.log, (x,), {}) + self._batched_grad_test(torch.log, (x,)) + self._batched_grad_grad_test(torch.log, (x,)) def test_logsumexp(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) @@ -2139,28 +2143,28 @@ def test_logsumexp(self, device): def op(x): return torch.logsumexp(x, -1) - self._batched_grad_test(op, (x,), {}) - self._batched_grad_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) def test_log1p(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) - self._batched_grad_test(torch.log1p, (x,), {}) - self._batched_grad_grad_test(torch.log1p, (x,), {}) + self._batched_grad_test(torch.log1p, (x,)) + self._batched_grad_grad_test(torch.log1p, (x,)) @allowVmapFallbackUsage def test_max(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(torch.max, (x,), {}) + self._batched_grad_test(torch.max, (x,)) @allowVmapFallbackUsage def test_median(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(torch.median, (x,), {}) + self._batched_grad_test(torch.median, (x,)) @allowVmapFallbackUsage def test_min(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(torch.min, (x,), {}) + self._batched_grad_test(torch.min, (x,)) def test_permute(self, device): x = torch.randn(2, 3, 5, requires_grad=True, device=device) @@ -2168,7 +2172,7 @@ def test_permute(self, device): def op(x): return x.permute(2, 0, 1) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) def test_reshape(self, device): x = torch.randn(2, 3, 5, requires_grad=True, device=device) @@ -2176,12 +2180,12 @@ def test_reshape(self, device): def op(x): return x.reshape([2 * 3, 5]) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) def test_sigmoid(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(Tensor.sigmoid, (x,), {}) - self._batched_grad_grad_test(Tensor.sigmoid, (x,), {}) + self._batched_grad_test(Tensor.sigmoid, (x,)) + self._batched_grad_grad_test(Tensor.sigmoid, (x,)) def test_stack(self, device): x = torch.randn(2, 3, device=device, requires_grad=True) @@ -2189,19 +2193,19 @@ def test_stack(self, device): def op(x, y): return torch.stack([x, y]) - self._batched_grad_test(op, (x, y), {}) + self._batched_grad_test(op, (x, y)) def test_select(self, device): x = torch.randn(2, 3, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x[1], (x,), {}) - self._batched_grad_test(lambda x: x.select(1, 2), (x,), {}) - self._batched_grad_test(lambda x: x.select(-1, 0), (x,), {}) + self._batched_grad_test(lambda x: x[1], (x,)) + self._batched_grad_test(lambda x: x.select(1, 2), (x,)) + self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) def test_slice(self, device): x = torch.randn(2, 3, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x[0:1], (x,), {}) - self._batched_grad_test(lambda x: x[:, 1:3], (x,), {}) - self._batched_grad_test(lambda x: x[..., 1:3], (x,), {}) + self._batched_grad_test(lambda x: x[0:1], (x,)) + self._batched_grad_test(lambda x: x[:, 1:3], (x,)) + self._batched_grad_test(lambda x: x[..., 1:3], (x,)) @allowVmapFallbackUsage def test_inplace_view(self, device): @@ -2236,10 +2240,10 @@ def func(leaf): def test_diagonal(self, device): x = torch.randn(4, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,), {}) + self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,)) x = torch.randn(3, 4, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,), {}) + self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) instantiate_device_type_tests( TestVmapBatchedGradient, From df887936a42e8ed102abd301d9f398797791a532 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 11 Nov 2020 14:01:45 -0800 Subject: [PATCH 24/93] Fix transpose batching rule (#47628) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47628 Pytorch has a special case where scalar_tensor.transpose(0, 0) works and returns the scalar tensor. If the following happens: ```py >>> x = torch.randn(B0) # the per-examples are all scalars >>> vmap(lambda x: x.transpose(0, 0), x) ``` then we replicate this behavior Test Plan: - new tests Reviewed By: anjali411 Differential Revision: D24843658 Pulled By: zou3519 fbshipit-source-id: e33834122652473e34a18ca1cecf98e8a3b84bc1 --- aten/src/ATen/BatchingRegistrations.cpp | 19 +++++++++++++++++-- test/test_vmap.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index 3c025e10af99..d6de7a6cb125 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -49,13 +49,19 @@ namespace at { // if not use the same mechanism. In order to accomplish that we might have to // do some refactoring. +// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor. +static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { + return dim == 0 || dim == -1; +} + Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional dtype) { // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail - // and instead returns a new scalar tensor. If the following happens: + // and instead returns a new scalar tensor (this also happens for dim=-1) + // If the following happens: // >>> x = torch.randn(B0) # the per-examples are all scalars // >>> vmap(partial(torch.sum, dim=0), x) // then we replicate the behavior of sum(scalar_tensor, dim=0). - if (/*logical*/self.dim() == 0 && dims.size() == 1 && dims[0] == 0) { + if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) { return self.clone(); } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); @@ -217,6 +223,15 @@ Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) { } Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) { + // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works + // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(lambda x: x.transpose(0, -1), x) + // then we replicate this behavior. + if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && + is_allowed_dim_on_scalar_tensor(dim1)) { + return self; + } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim0_physical = self_physical.getPhysicalDim(dim0); auto dim1_physical = self_physical.getPhysicalDim(dim1); diff --git a/test/test_vmap.py b/test/test_vmap.py index 88c95872414a..9f5725870675 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -2,6 +2,7 @@ import torch from torch import Tensor, vmap import functools +import itertools import warnings from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import TEST_WITH_ROCM @@ -1710,12 +1711,14 @@ def test_sum_dim(self): # Single vmap, various in_dims / out_dims test(lambda x: x.sum(0), [torch.randn([B0])]) + test(lambda x: x.sum(-1), [torch.randn([B0])]) test(lambda x: x.sum(0), [torch.randn([B0, 3])]) test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2) test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) # Doubly nested vmap test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])]) + test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])]) test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) @@ -1830,6 +1833,25 @@ def test_split(self): test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + def test_transpose(self): + op = torch.transpose + test = self._vmap_view_test + + B0, B1, B2 = 7, 11, 13 + test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) + test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + # Special case: scalar tensor + for dim1, dim2 in itertools.product([0, -1], [0, -1]): + x = torch.rand(B0) + result = vmap(lambda x: op(x, dim1, dim2))(x) + self.assertTrue(result is x) + def test_t(self): op = torch.t test = self._vmap_view_test From 05a76ed705b75655605987f699cf67df50180011 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 11 Nov 2020 14:01:45 -0800 Subject: [PATCH 25/93] Batching rule for torch.squeeze(tensor) (#47632) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47632 This one is fun because we have to be careful not to squeeze out any of the batch dims (it is the dims of the per-example tensor that are being squeezed). Test Plan: - new tests Reviewed By: anjali411 Differential Revision: D24859022 Pulled By: zou3519 fbshipit-source-id: 8adbd80963081efb683f62ea074a286a10da288f --- aten/src/ATen/BatchingRegistrations.cpp | 22 ++++++++++++++++++++++ test/test_vmap.py | 12 ++++++++++++ 2 files changed, 34 insertions(+) diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index d6de7a6cb125..c30ddb631d0a 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -215,6 +215,27 @@ Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { return self_physical.newLogicalFromPhysical(result); } +Tensor squeeze_batching_rule(const Tensor& self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_sizes = self_physical.tensor().sizes(); + + // Don't squeeze the batch dims! + VmapDimVector squeezed_sizes; + int64_t num_batch_dims = self_physical.numBatchDims(); + squeezed_sizes.insert( + squeezed_sizes.end(), + physical_sizes.begin(), + physical_sizes.begin() + num_batch_dims); + for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) { + if (*it != 1) { + squeezed_sizes.push_back(*it); + } + } + + auto result = self_physical.tensor().view(squeezed_sizes); + return self_physical.newLogicalFromPhysical(result); +} + Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); @@ -950,6 +971,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("slice.Tensor", slice_batching_rule); m.impl("split.Tensor", split_batching_rule); m.impl("split_with_sizes", split_with_sizes_batching_rule); + m.impl("squeeze", squeeze_batching_rule); m.impl("squeeze.dim", squeeze_dim_batching_rule); m.impl("t", native::t); // composite wrt autograd m.impl("transpose.int", transpose_int_batching_rule); diff --git a/test/test_vmap.py b/test/test_vmap.py index 9f5725870675..5f67498f1437 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -1705,6 +1705,18 @@ def test_slice(self): test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), (torch.rand(3, 5, B0, B1, B2),), in_dims=2) + def test_squeeze(self): + test = self._vmap_view_test + op = torch.squeeze + B0, B1 = 1, 11 + test(op, (torch.rand(B0),)) + test(op, (torch.rand(B0, 3, 5),)) + test(op, (torch.rand(1, B0, 5),), in_dims=1) + test(op, (torch.rand(B0, 0, 1, 5, 1),)) + test(op, (torch.rand(B0, 1, 1, 1, 1),)) + test(vmap(op), (torch.rand(B0, B1, 1),)) + test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) + def test_sum_dim(self): test = self._vmap_test B0, B1 = 5, 7 From 7864ae9f987f3ce848422f08d849604a3ea00343 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Wed, 11 Nov 2020 14:18:05 -0800 Subject: [PATCH 26/93] Improve error messages for operator registration API (#47636) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47636 Previously: ``` terminate called after throwing an instance of 'c10::Error' what(): *cpp_signature == cpp_signature_->signature INTERNAL ASSERT FAILED at "caffe2/aten/src/ATen/core/dispatch/OperatorEntry.cpp":92, please report a bug to PyTorch. Tried to register a kernel (registered at buck-out/dev/gen/caffe2/generate-code/autograd/generated/TraceType_2.cpp:9847) for operator aten::div.out (registered at buck-out/dev/gen/caffe2/aten/gen_aten=TypeDefault.cpp/TypeDefault.cpp:3541) for dispatch key Tracer, but the C++ function signature at::Tensor& (at::Tensor const&, at::Tensor const&, at::Tensor&) mismatched with a previous kernel (registered at buck-out/dev/gen/caffe2/aten/gen_aten=CPUType.cpp/CPUType.cpp:2166) that had the signature at::Tensor& (at::Tensor&, at::Tensor const&, at::Tensor const&) ``` Now: ``` terminate called after throwing an instance of 'c10::Error' what(): *cpp_signature == cpp_signature_->signature INTERNAL ASSERT FAILED at "caffe2/aten/src/ATen/core/dispatch/OperatorEntry.cpp":96, please report a bug to PyTorch. Mismatch in kernel C++ signatures operator: aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> (Tensor(a!)) registered at buck-out/dev/gen/caffe2/aten/gen_aten=TypeDefault.cpp/TypeDefault.cpp:3541 kernel 1: at::Tensor& (at::Tensor&, at::Tensor const&, at::Tensor const&) dispatch key: CPU registered at buck-out/dev/gen/caffe2/aten/gen_aten=CPUType.cpp/CPUType.cpp:2166 kernel 2: at::Tensor& (at::Tensor const&, at::Tensor const&, at::Tensor&) dispatch key: Tracer registered at buck-out/dev/gen/caffe2/generate-code/autograd/generated/TraceType_2.cpp:9847 ``` Previously: ``` W1109 13:38:52.464170 1644302 OperatorEntry.cpp:117] Warning: Registering a kernel (registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:310) for operator aten::_backward (registered at buck-out/dev/gen/caffe2/aten/gen_aten=TypeDefault.cpp/TypeDefault.cpp:3549) for dispatch key Autograd that overwrote a previously registered kernel (registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:310) with the same dispatch key for the same operator. (function registerKernel) ``` Now: ``` W1109 13:49:40.501817 1698959 OperatorEntry.cpp:118] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key operator: aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () registered at buck-out/dev/gen/caffe2/aten/gen_aten=TypeDefault.cpp/TypeDefault.cpp:3549 dispatch key: Autograd previous kernel: registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:310 new kernel: registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:310 (function registerKernel) ``` Previously: ``` terminate called after throwing an instance of 'c10::Error' what(): In registration for dummy_library::dummy_op: expected schema of operator to be "dummy_library::dummy_op(Tensor a) -> (Tensor)" (registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:298), but got inferred schema "(Tensor _0) -> ()" (registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:298). The number of returns is different. 1 vs 0 ``` Now: ``` terminate called after throwing an instance of 'c10::Error' what(): Inferred operator schema for a C++ kernel function doesn't match the expected function schema. operator: dummy_library::dummy_op expected schema: dummy_library::dummy_op(Tensor a) -> (Tensor) registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:298 inferred schema: (Tensor _0) -> () registered at caffe2/torch/csrc/autograd/VariableTypeManual.cpp:298 reason: The number of returns is different. 1 vs 0 ```` Previously: ``` terminate called after throwing an instance of 'c10::Error' what(): !cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature) INTERNAL ASSERT FAILED at "caffe2/aten/src/ATen/core/dispatch/OperatorEntry.h":170, please report a bug to PyTorch. Tried to access operator _test::dummy with a wrong signature. Accessed with void (at::Tensor, long) but the operator was registered with void (at::Tensor) (schema: registered by RegisterOperators, kernel: registered by RegisterOperators) This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call. ``` Now: ``` terminate called after throwing an instance of 'c10::Error' what(): !cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature) INTERNAL ASSERT FAILED at "caffe2/aten/src/ATen/core/dispatch/OperatorEntry.h":169, please report a bug to PyTorch. Tried to access or call an operator with a wrong signature. operator: _test::dummy(Tensor dummy) -> () registered by RegisterOperators correct signature: void (at::Tensor) registered by RegisterOperators accessed/called as: void (at::Tensor, long) This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call. ``` ghstack-source-id: 116359052 Test Plan: waitforsandcastle Reviewed By: ezyang Differential Revision: D24846523 fbshipit-source-id: 0ce7d487b725bfbdf2261e36027cb34ef50c1fea --- aten/src/ATen/core/dispatch/OperatorEntry.cpp | 41 +++++++++++-------- aten/src/ATen/core/dispatch/OperatorEntry.h | 20 ++++----- .../op_registration/op_registration_test.cpp | 32 +++++++-------- 3 files changed, 50 insertions(+), 43 deletions(-) diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index a3cef61f4c21..f0d7bc6968ed 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -36,9 +36,13 @@ namespace { c10::optional schema_difference = findSchemaDifferences(from_def, inferred); if (schema_difference.has_value()) { TORCH_CHECK(false, - "In registration for ", toString(name), ": expected schema of operator to be \"", toString(from_def), "\" (", from_def_debug, "), ", - "but got inferred schema \"", toString(inferred), "\" (", inferred_debug, "). ", - *schema_difference); + "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n" + " operator: ", toString(name), "\n", + " expected schema: ", toString(from_def), "\n", + " ", from_def_debug, "\n", + " inferred schema: ", toString(inferred), "\n", + " ", inferred_debug, "\n", + " reason: ", *schema_difference); } } } // anonymous namespace @@ -83,15 +87,19 @@ std::list::iterator OperatorEntry::registerKernel( // that would also invalidate the old TypedOperatorHandles. if (cpp_signature.has_value()) { if (cpp_signature_.has_value()) { - TORCH_INTERNAL_ASSERT(*cpp_signature == cpp_signature_->signature, - "Tried to register a kernel (", debug, ") for operator ", name_," (", - (this->schema_.has_value() ? this->schema_->debug : "no debug info"), - ") for dispatch key ", toString(dispatch_key), ", but the C++ function signature ", - cpp_signature->name(), " mismatched with a previous kernel (", cpp_signature_->debug, - ") that had the signature ", cpp_signature_->signature.name() + TORCH_CHECK(*cpp_signature == cpp_signature_->signature, + "\nMismatch in kernel C++ signatures\n", + " operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n", + " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", + " kernel 1: ", cpp_signature_->signature.name(), "\n", + " dispatch key: ", toString(cpp_signature_->dispatch_key), "\n", + " ", cpp_signature_->debug, "\n", + " kernel 2: ", cpp_signature->name(), "\n", + " dispatch key: ", toString(dispatch_key), "\n", + " ", debug, "\n" ); } else { - cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug }; + cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; } } @@ -105,12 +113,13 @@ std::list::iterator OperatorEntry::registerKernel( auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math]; if (k.size() > 0) { - TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " (", - (this->schema_.has_value() ? this->schema_->debug : "no debug info"), - ") for dispatch key ", toString(dispatch_key), - " that overwrote a previously registered kernel (", - (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), - ") with the same dispatch key for the same operator."); + TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n", + " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", + " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", + " dispatch key: ", toString(dispatch_key), "\n", + " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n", + " new kernel: ", debug + ); } if (manuallyBoxedKernel_.has_value()) { diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 26506cb0f76f..79af2243d420 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -157,16 +157,15 @@ class CAFFE2_API OperatorEntry final { // Asserts that the given FuncType is correct for calling this operator in an unboxed way. template void assertSignatureIsCorrect() { - TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature), - "Tried to access operator ", name_, " with a wrong signature. Accessed with ", - CppSignature::make().name(), - " but the operator was registered with ", - cpp_signature_->signature.name(), - " (schema: ", - (schema_.has_value() ? schema_->debug : "unknown debug info"), - ", kernel: ", - cpp_signature_->debug, - ") This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call." + TORCH_CHECK(!cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature), + "\nTried to access or call an operator with a wrong signature.\n", + " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", + " ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n", + " correct signature: ", cpp_signature_->signature.name(), "\n", + " ", cpp_signature_->debug, "\n", + " accessed/called as: ", CppSignature::make().name(), "\n", + "This likely happened in a call to OperatorHandle::typed(). ", + "Please make sure that the function signature matches the signature in the operator registration call." ); } @@ -241,6 +240,7 @@ class CAFFE2_API OperatorEntry final { struct CppSignatureWithDebug { CppSignature signature; std::string debug; + c10::optional dispatch_key; }; c10::optional cpp_signature_; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 3fd8740d1ab1..9239e78133cd 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -308,10 +308,9 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis testing::internal::CaptureStderr(); c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); std::string output = testing::internal::GetCapturedStderr(); - EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); - EXPECT_THAT(output, testing::HasSubstr("CPU")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); - EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("Overriding a previously registered kernel for the same operator and the same dispatch key")); + EXPECT_THAT(output, testing::HasSubstr("operator: _test::dummy(Tensor dummy) -> ()")); + EXPECT_THAT(output, testing::HasSubstr("dispatch key: CPU")); } TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) { @@ -347,10 +346,9 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegistering_then testing::internal::CaptureStderr(); c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel()); std::string output = testing::internal::GetCapturedStderr(); - EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); - EXPECT_THAT(output, testing::HasSubstr("catch all")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); - EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("Overriding a previously registered kernel for the same operator and the same dispatch key")); + EXPECT_THAT(output, testing::HasSubstr("operator: _test::dummy(Tensor dummy) -> ()")); + EXPECT_THAT(output, testing::HasSubstr("dispatch key: (catch all)")); } TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) { @@ -703,7 +701,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(c10::DispatchKey::CPU) .kernel(c10::DispatchKey::CUDA, &called_kernel)); - }, "mismatched with a previous kernel"); + }, "Mismatch in kernel C++ signatures"); } void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) { @@ -946,7 +944,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel"); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) { @@ -955,7 +953,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBacke expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel"); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) { @@ -964,7 +962,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchA expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .catchAllKernel([] (const int64_t&) {})); - }, "mismatched with a previous kernel"); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -973,7 +971,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCpp expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) { @@ -982,7 +980,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismat expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingCppSignatures_thenFails) { @@ -991,7 +989,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingC m.impl("dummy", DispatchKey::CPU, [] (int64_t) {}); expectThrows([&] { m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {}); - }, "mismatched with a previous kernel"); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -1001,7 +999,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCpp expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) { @@ -1010,7 +1008,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismat expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()"); } /** From fcd44ce6982fc47d153f13cfc1f962e85357e386 Mon Sep 17 00:00:00 2001 From: skyline75489 Date: Wed, 11 Nov 2020 14:22:09 -0800 Subject: [PATCH 27/93] Add instruction on how to handle the potential linker error on Linux (#47593) Summary: The original issue is https://github.com/pytorch/pytorch/issues/16683, which contains a https://github.com/pytorch/pytorch/issues/16683#issuecomment-459982988 that suggests manually un-shadowing the `ld`. A better approach can be found at https://github.com/ContinuumIO/anaconda-issues/issues/11152#issuecomment-573120962, which suggests that using a newer version can effectively fix this. It took me quite some time to realize that this is in fact an issue caused by Anaconda. I think we should add it in README. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47593 Reviewed By: ailzhang Differential Revision: D24866092 Pulled By: heitorschueroff fbshipit-source-id: c1f51864d23fd6f4f63a117496d8619053e35196 --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c6c1138747a2..ea21480eb3ce 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) ### From Source -If you are installing from source, you will need Python 3.6 or later and a C++14 compiler. Also, we highly recommend installing an [Anaconda](https://www.anaconda.com/distribution/#download-section) environment. +If you are installing from source, you will need Python 3.6.2 or later and a C++14 compiler. Also, we highly recommend installing an [Anaconda](https://www.anaconda.com/distribution/#download-section) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. Once you have [Anaconda](https://www.anaconda.com/distribution/#download-section) installed, here are the instructions. @@ -207,6 +207,16 @@ export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py install ``` +Note that if you are using [Anaconda](https://www.anaconda.com/distribution/#download-section), you may experience an error caused by the linker: + +```plaintext +build/temp.linux-x86_64-3.7/torch/csrc/stub.o: file not recognized: file format not recognized +collect2: error: ld returned 1 exit status +error: command 'g++' failed with exit status 1 +``` + +This is caused by `ld` from Conda environment shadowing the system `ld`. You should use a newer version of Python that fixes this issue. The recommended Python version is 3.6.10+, 3.7.6+ and 3.8.1+. + On macOS ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} From f2b7c38735f2638fd32326cb40d5b5d458edc814 Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Wed, 11 Nov 2020 14:48:47 -0800 Subject: [PATCH 28/93] Automated submodule update: FBGEMM (#47605) Summary: This is an automated pull request to update the first-party submodule for [pytorch/FBGEMM](https://github.com/pytorch/FBGEMM). New submodule commit: https://github.com/pytorch/FBGEMM/commit/eb55572e5524f5a304f359bd7a193ba7ea57a91d Pull Request resolved: https://github.com/pytorch/pytorch/pull/47605 Test Plan: Ensure that CI jobs succeed on GitHub before landing. Reviewed By: jianyuh Differential Revision: D24833658 Pulled By: heitorschueroff fbshipit-source-id: 7a577c75d244a58d94c249c0e50992078a3b62cb --- third_party/fbgemm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/fbgemm b/third_party/fbgemm index 8eb6dcb23eee..92c5f37b430a 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 8eb6dcb23eee3b21c0e093a27810cb4a62dd3e27 +Subproject commit 92c5f37b430a66905bd03514c510ee236aca7cc0 From 0c54ea50bd6d3684e7dc13bb91f40d3d269dff8d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 11 Nov 2020 15:08:17 -0800 Subject: [PATCH 29/93] [PyTorch] Avoid atomic refcounting in intrusive_ptr::make (#47100) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47100 Profiling with Linux `perf` shows that we spend at least 1% of our time doing this increment in our framework overhead benchmark. Here's the inline function breakdown for empty_cpu, which takes 6.91% of the total time: ``` - at::native::empty_cpu - 1.91% at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&> (inlined) - 0.98% c10::make_intrusive, c10::intrusive_ptr >, c10::DispatchKey, caffe2::TypeMeta&> (inlined 0.97% c10::intrusive_ptr >::make >, c10::DispatchKey, caffe2::TypeMeta&> 0.84% intrusive_ptr > (inlined) - 1.44% c10::make_intrusive, c10::StorageImpl::use_byte_size_t, long&, c10::DataPtr, c10::Allocator*&, bool> (inlined) - 1.44% c10::intrusive_ptr >::make (inlined) 1.02% std::__atomic_base::operator++ (inlined) - 0.80% ~DataPtr (inlined) ~UniqueVoidPtr (inlined) ~unique_ptr (inlined) - 0.78% c10::TensorOptions::memory_format (inlined) - c10::TensorOptions::set_memory_format (inlined) - c10::optional::operator bool (inlined) c10::optional::initialized (inlined) ``` This change comes with a caveat: if we have constructors where `this` escapes to another thread before returning, we cannot make this assumption, because that other thread may have called `intrusive_ptr::make` already. I chose to just mandate that `instrusive_ptr_target`s's ctors hand back exclusive ownership of `this`, which seems like a reasonable requirement for a ctor anyway. If that turns out to be unacceptable, we could provide an opt-out from this optimization via a traits struct or similar template metaprogramming shenanigan. ghstack-source-id: 116368592 Test Plan: Run framework overhead benchmark. Results look promising, ranging from a tiny regression (? presumably noise) on the InPlace benchmark, 2.5% - 4% on OutOfPlace, to 9% on the empty benchmarks and 10-12% on the view benchmarks. Reviewed By: ezyang Differential Revision: D24606531 fbshipit-source-id: 1cf022063dab71cd1538535c72c4844d8dd7bb25 --- c10/util/intrusive_ptr.h | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 825b934852d4..f79e2dba3aa2 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -28,7 +28,8 @@ namespace raw { * performance because it does the refcounting intrusively * (i.e. in a member of the object itself). * Your class T needs to inherit from intrusive_ptr_target to allow it to be - * used in an intrusive_ptr. + * used in an intrusive_ptr. Your class's constructor should not allow + *`this` to escape to other threads or create an intrusive_ptr from `this`. */ // Note [Stack allocated intrusive_ptr_target safety] @@ -396,7 +397,22 @@ class intrusive_ptr final { */ template static intrusive_ptr make(Args&&... args) { - return intrusive_ptr(new TTarget(std::forward(args)...)); + auto result = intrusive_ptr(new TTarget(std::forward(args)...), raw::DontIncreaseRefcount{}); + + // We just created result.target_, so we know no other thread has + // access to it, so we know we needn't care about memory ordering. + // (On x86_64, a store with memory_order_relaxed generates a plain old + // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is + // much more expensive: https://godbolt.org/z/eKPzj8.) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + result.target_->refcount_ == 0 && result.target_->weakcount_ == 0, + "intrusive_ptr: Newly-created target had non-zero refcounts. Does its " + "constructor do something strange like incref or create an intrusive_ptr" + "from `this`?"); + result.target_->refcount_.store(1, std::memory_order_relaxed); + result.target_->weakcount_.store(1, std::memory_order_relaxed); + + return result; } /** From 32b4b51254dce410484cfc68052729227505ee24 Mon Sep 17 00:00:00 2001 From: Omkar Salpekar Date: Wed, 11 Nov 2020 15:18:54 -0800 Subject: [PATCH 30/93] [Docs] Minor doc fixes for init_process_group (#47644) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47644 Minor Update to the init_process_group docs. ghstack-source-id: 116441798 Test Plan: CI Reviewed By: jiayisuse, mrshenli Differential Revision: D24633432 fbshipit-source-id: fbd38dab464ee156d119f9f0b22ffd0e416c4fd7 --- torch/distributed/distributed_c10d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 6f07c9412d5d..d97fa774ef30 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -401,11 +401,11 @@ def init_process_group(backend, asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` will provide errors to the user which can be caught and handled, but due to its blocking nature, it has a performance overhead. On - the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has little + the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little performance overhead, but crashes the process on errors. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations - might result in subsequent CUDA operations to run on corrupted + might result in subsequent CUDA operations running on corrupted data. Only one of these two environment variables should be set. group_name (str, optional, deprecated): Group name. From d4fa84bf5f3606c5ad14671408108df991ab1bca Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Wed, 11 Nov 2020 15:23:10 -0800 Subject: [PATCH 31/93] Properly serialize types that only appear at function input (#47775) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47775 When serializing graphs, we check every node for named types referenced, so that we can register them as dependencies. We were skipping this check for the graph inputs themselves. Since types used at input are almost always used somewhere in the graph, we never noticed this gap until a user reported an issue with NamedTuples. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D24896289 Pulled By: suo fbshipit-source-id: 4ce76816cb7997a7b65e7cea152ea52ed8f27276 --- test/jit/test_save_load.py | 38 ++++++++++++++++++- torch/csrc/jit/serialization/python_print.cpp | 8 ++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 23751c4fd92b..31b0124ae802 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -7,7 +7,7 @@ from itertools import product as product from torch import Tensor from torch.testing._internal.common_utils import TemporaryFileName -from typing import NamedTuple +from typing import NamedTuple, Optional # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -947,3 +947,39 @@ def forward(self, x): script_module = torch.jit.script(Foo()) with self.assertRaises(RuntimeError): script_module.save("NonExist/path/test.pt") + + def test_save_namedtuple_input_only(self): + """ + Even if a NamedTuple is only used as an input argument, saving and + loading should work correctly. + """ + global FooTuple # see [local resolution in python] + + class FooTuple(NamedTuple): + a: int + + class MyModule(torch.nn.Module): + def forward(self, x: FooTuple) -> torch.Tensor: + return torch.tensor(3) + + m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) + output = m_loaded(FooTuple(a=5)) + self.assertEqual(output, torch.tensor(3)) + + def test_save_namedtuple_output_only(self): + """ + Even if a NamedTuple is only used as an output argument, saving and + loading should work correctly. + """ + global FooTuple # see [local resolution in python] + + class FooTuple(NamedTuple): + a: int + + class MyModule(torch.nn.Module): + def forward(self) -> Optional[FooTuple]: + return None + + m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) + output = m_loaded() + self.assertEqual(output, None) diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 9803829eb683..b9b1d60640c2 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -1255,6 +1255,7 @@ struct PythonPrintImpl { body_ << "def " << func.name() << "("; auto param_it = graph.inputs().begin(); for (const Argument& arg : schema.arguments()) { + registerClassDependencies(arg.type()); std::string arg_name = genName(arg.name()); if (param_it == graph.inputs().begin()) { // the first argument may omit its type when it is implied by context @@ -1273,9 +1274,10 @@ struct PythonPrintImpl { assignValue(*param_it++, arg_name); } - body_ << ") -> " - << schema.returns().at(0).type()->annotation_str(type_printer_) - << ":\n"; + const auto& returnType = schema.returns().at(0).type(); + body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n"; + registerClassDependencies(returnType); + printBody(graph.block()); } From 545f624a4a18b49c603e49768201cd9971e3f931 Mon Sep 17 00:00:00 2001 From: Akshit Khurana Date: Wed, 11 Nov 2020 15:54:07 -0800 Subject: [PATCH 32/93] Mark overriden Tensor method `override` (#47198) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47198 Fixes: ```xplat/caffe2/aten/src/ATen/native/xnnpack/OpContext.h:77:10: error: 'run' overrides a member function but is not marked 'override' [-Werror,-Winconsistent-missing-override] Tensor run(const Tensor& input);``` Test Plan: CI tests Reviewed By: kimishpatel Differential Revision: D24678573 fbshipit-source-id: 244769cc36d3c1126973a67441aa2d06d2b83b9c --- aten/src/ATen/native/xnnpack/OpContext.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/xnnpack/OpContext.h b/aten/src/ATen/native/xnnpack/OpContext.h index e696ad3aa81d..a3902dab7b67 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.h +++ b/aten/src/ATen/native/xnnpack/OpContext.h @@ -70,7 +70,7 @@ class XNNPackLinearOpContext final : public LinearOpContext { output_max_ = max; } - Tensor run(const Tensor& input); + Tensor run(const Tensor& input) override; static c10::intrusive_ptr create_context( Tensor&& weight, From a0c4aae3d59a3637e7b39acdaf3c231414dc6e87 Mon Sep 17 00:00:00 2001 From: Akshit Khurana Date: Wed, 11 Nov 2020 15:54:07 -0800 Subject: [PATCH 33/93] Free original weight after prepacking in XNNPACK based op (#46541) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46541 When weights are prepacked XNNPACK packs the weights in a separate memory. After that original weights are not needed for inference. Having those weights lying around increase memory footprint, so we would like to remove the original weights once prepacking is done. Test Plan: buck test //caffe2/aten:mobile_memory_cleanup Reviewed By: kimishpatel Differential Revision: D24280928 fbshipit-source-id: 90ffc53b1eabdc545a3ccffcd17fa3137d500cbb --- aten/src/ATen/native/xnnpack/OpContext.cpp | 30 +++++++++++++++ aten/src/ATen/native/xnnpack/OpContext.h | 15 ++++++++ aten/src/ATen/test/CMakeLists.txt | 1 + aten/src/ATen/test/mobile_memory_cleanup.cpp | 39 ++++++++++++++++++++ 4 files changed, 85 insertions(+) create mode 100644 aten/src/ATen/test/mobile_memory_cleanup.cpp diff --git a/aten/src/ATen/native/xnnpack/OpContext.cpp b/aten/src/ATen/native/xnnpack/OpContext.cpp index fe78dcda1f99..fe525cb2df4d 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.cpp +++ b/aten/src/ATen/native/xnnpack/OpContext.cpp @@ -27,9 +27,19 @@ XNNPackLinearOpContext::create_context( output_max ? output_max->to() : xnnpack::ContextLinear::kMax) ); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + linear_op_context->free_orig_weight_and_bias(); + } + return linear_op_context; } +void XNNPackLinearOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + Tensor XNNPackLinearOpContext::run(const Tensor& input) { return xnnpack::internal::linear::run(op_context_, input); } @@ -70,6 +80,10 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight, output_max, std::move(op_context)); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + conv2d_op_context->free_orig_weight_and_bias(); + } + return conv2d_op_context; } @@ -111,6 +125,10 @@ XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight, output_max, std::move(op_context)); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + conv2d_op_context->free_orig_weight_and_bias(); + } + return conv2d_op_context; } @@ -122,6 +140,18 @@ Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) { return xnnpack::internal::convolution2d::run(op_context_, input); } +void XNNPackConv2dOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + +void XNNPackTransposeConv2dOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + } // namespace xnnpack } // namespace native } // namespace at diff --git a/aten/src/ATen/native/xnnpack/OpContext.h b/aten/src/ATen/native/xnnpack/OpContext.h index a3902dab7b67..e26c3383d6a6 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.h +++ b/aten/src/ATen/native/xnnpack/OpContext.h @@ -43,13 +43,16 @@ class LinearOpContext : public torch::jit::CustomClassHolder { c10::optional orig_bias_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeLinearPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple(orig_weight_, orig_bias_, output_min_, output_max_); } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class XNNPackLinearOpContext final : public LinearOpContext { @@ -68,9 +71,11 @@ class XNNPackLinearOpContext final : public LinearOpContext { orig_bias_ = std::move(bias); output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, @@ -89,9 +94,11 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { int64_t groups_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeConv2dPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple( orig_weight_, orig_bias_, @@ -104,6 +111,7 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { @@ -117,9 +125,11 @@ class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { int64_t groups_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeTransposeConv2dPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple( orig_weight_, orig_bias_, @@ -133,6 +143,7 @@ class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class XNNPackConv2dOpContext final : public Conv2dOpContext { @@ -159,9 +170,11 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext { groups_ = groups; output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, @@ -200,9 +213,11 @@ class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext { groups_ = groups; output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 067902c0a3b7..50bdde11b0c9 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -29,6 +29,7 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math_kernel_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory_overlapping_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mobile_memory_cleanup.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp diff --git a/aten/src/ATen/test/mobile_memory_cleanup.cpp b/aten/src/ATen/test/mobile_memory_cleanup.cpp new file mode 100644 index 000000000000..8682fd0a4f15 --- /dev/null +++ b/aten/src/ATen/test/mobile_memory_cleanup.cpp @@ -0,0 +1,39 @@ +#include + +#include +#include + +using namespace torch::jit; + +#ifdef USE_XNNPACK + +TEST(MemoryCleanUp, NoErrorWithoutRelease) { + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + )"); + m.eval(); + auto m_optimized = optimizeForMobile(m); + std::stringstream ss; + EXPECT_NO_THROW(m_optimized.save(ss)); +} + +TEST(MemoryCleanUp, UnpackError) { + at::globalContext().setReleaseWeightsWhenPrepacking(true); + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + )"); + m.eval(); + auto m_optimized = optimizeForMobile(m); + std::stringstream ss; + EXPECT_ANY_THROW(m_optimized.save(ss)); +} + +#endif From 52ec8b9340bd7fce956817dca75498e5723d1b80 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 11 Nov 2020 16:05:56 -0800 Subject: [PATCH 34/93] Added CUDA support for complex input for torch.triangular_solve (#46916) Summary: `torch.triangular_solve` now works for complex inputs on GPU. I moved the existing tests to `test_linalg.py` and modified them to test complex and float32 dtypes. Ref. https://github.com/pytorch/pytorch/issues/33152 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46916 Reviewed By: navahgar, agolynski Differential Revision: D24706647 Pulled By: anjali411 fbshipit-source-id: fe780eac93d2ae1b2549539bb385e5fac25213b3 --- .../ATen/native/cuda/BatchLinearAlgebra.cu | 50 ++++- test/test_autograd.py | 21 +- test/test_linalg.py | 185 +++++++++++++++++- test/test_torch.py | 120 ------------ torch/_torch_docs.py | 2 + .../_internal/common_methods_invocations.py | 1 + 6 files changed, 235 insertions(+), 144 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index c0bc2d915bd0..5379d38fa43f 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -586,6 +586,30 @@ void magmaTriangularSolve( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaTriangularSolve>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex* dA, magma_int_t ldda, c10::complex* dB, magma_int_t lddb) { + MagmaStreamSyncGuard guard; + magmaDoubleComplex alpha({1, 0}); + magma_ztrsm(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA), ldda, + reinterpret_cast(dB), lddb); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaTriangularSolve>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex* dA, magma_int_t ldda, c10::complex* dB, magma_int_t lddb) { + MagmaStreamSyncGuard guard; + magmaFloatComplex alpha({1, 0}); + magma_ctrsm(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA), ldda, + reinterpret_cast(dB), lddb); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaTriangularSolveBatched( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, @@ -604,6 +628,30 @@ void magmaTriangularSolveBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaTriangularSolveBatched>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex** dA_array, magma_int_t ldda, c10::complex** dB_array, magma_int_t lddb, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magmaDoubleComplex alpha({1, 0}); + magmablas_ztrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaTriangularSolveBatched>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex** dA_array, magma_int_t ldda, c10::complex** dB_array, magma_int_t lddb, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magmaFloatComplex alpha({1, 0}); + magmablas_ctrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> inline magma_int_t magmaGeqrfOptimalBlocksize(magma_int_t m, magma_int_t n) { return magma_get_dgeqrf_nb(m, n); @@ -1483,7 +1531,7 @@ std::tuple _triangular_solve_helper_cuda(const Tensor& self, con bool upper, bool transpose, bool unitriangular) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{ apply_triangular_solve(self_working_copy, A_working_copy, upper, transpose, unitriangular); }); return std::tuple(self_working_copy, A_working_copy); diff --git a/test/test_autograd.py b/test/test_autograd.py index 3003d4c06403..34ed8b72867e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2777,25 +2777,6 @@ def func(A, upper): for upper, dims in product([True, False], [(3, 3), (5, 5)]): _test_with_size(upper, dims) - @skipIfNoLapack - def test_triangular_solve(self): - def run_test(A_dims, B_dims, dtype): - A = torch.rand(*A_dims, dtype=dtype).requires_grad_() - b = torch.rand(*B_dims, dtype=dtype).requires_grad_() - - for upper, transpose, unitriangular in product((True, False), repeat=3): - def func(A, b): - return torch.triangular_solve(b, A, upper, transpose, unitriangular) - - gradcheck(func, [A, b]) - gradgradcheck(func, [A, b]) - - for dtype in (torch.double, torch.cdouble): - run_test((3, 3), (3, 4), dtype) - run_test((3, 3), (3, 2), dtype) - run_test((2, 3, 3), (2, 3, 4), dtype) - run_test((2, 3, 3), (2, 3, 2), dtype) - @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") def test_fft_ifft_rfft_irfft(self): def _test_complex(sizes, signal_ndim): @@ -5006,7 +4987,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', - 'exp', 'mean', 'inverse'] + separate_complex_tests + 'exp', 'mean', 'inverse', 'triangular_solve'] + separate_complex_tests # this list corresponds to cases that are not currently implemented skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex', 'inverse_batched_complex'] diff --git a/test/test_linalg.py b/test/test_linalg.py index d33d3cfe98ce..a64ea5302447 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -6,13 +6,13 @@ from random import randrange from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) + (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCUDA, - onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, onlyOnCPUAndCUDA) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args -from torch.autograd import gradcheck +from torch.autograd import gradcheck, gradgradcheck if TEST_NUMPY: import numpy as np @@ -1462,6 +1462,185 @@ def check(equation, operands, regex, exception=RuntimeError): check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') + def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, + device, dtype): + triangle_function = torch.triu if upper else torch.tril + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = torch.randn(*A_dims, dtype=dtype, device=device) + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + A_tmp = torch.empty_like(A).view(-1, *A_dims[-2:]) + for A_i, A_tmp_i in zip(A.contiguous().view(-1, *A_dims[-2:]), A_tmp): + torch.matmul(A_i, A_i.t(), out=A_tmp_i) + A = A_tmp.view(*A_dims) + else: + # create positive definite matrix + A = torch.matmul(A, A.transpose(-2, -1)) + A_triangular = triangle_function(A) + if unitriangular: + A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) + return b, A_triangular + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve(self, device, dtype): + for (k, n), (upper, unitriangular, transpose) in itertools.product(zip([2, 3, 5], [3, 5, 7]), + itertools.product([True, False], repeat=3)): + b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, + unitriangular, device, dtype) + x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + if transpose: + self.assertEqual(b, A.t().mm(x)) + else: + self.assertEqual(b, A.mm(x)) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve_batched(self, device, dtype): + def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.triangular_solve(b, A, upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0] # Actual output + self.assertEqual(x_act, x_exp) # Equality check + if transpose: + A = A.transpose(-2, -1) + + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + Ax = torch.empty_like(x_act).view(-1, *b_dims[-2:]) + for A_i, x_i, Ax_i in zip(A.contiguous().view(-1, *A_dims[-2:]), + x_act.contiguous().view(-1, *b_dims[-2:]), Ax): + torch.matmul(A_i, x_i, out=Ax_i) + Ax = Ax.view(*x_act.shape) + else: + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) + + for (upper, unitriangular, transpose), batchsize in itertools.product(itertools.product( + [True, False], repeat=3), [1, 3, 4]): + triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), + upper, unitriangular, transpose) + + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve_batched_many_batches(self, device, dtype): + for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): + # test batched A case + b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), + upper, unitriangular, device, dtype) + x, _ = torch.triangular_solve(b, A, + upper=upper, transpose=transpose, unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + Ax = torch.empty_like(x).view(-1, 5, 1) + for A_i, x_i, Ax_i in zip(A.contiguous().view(-1, 5, 5), x.contiguous().view(-1, 5, 1), Ax): + torch.matmul(A_i, x_i, out=Ax_i) + Ax = Ax.view(256, 256, 5, 1) + else: + Ax = torch.matmul(A, x) + + rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision + self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol) + + # test batched b case + b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), + upper, unitriangular, device, dtype) + x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, + unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + + self.assertEqual(torch.matmul(A, x), b) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_triangular_solve_batched_broadcasting(self, device, dtype): + from scipy.linalg import solve_triangular as tri_solve + + def scipy_tri_solve_batched(A, B, upper, trans, diag): + batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] + single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] + expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), + torch.Size(batch_dims_B))) + expand_A = np.broadcast_to(A, expand_dims + single_dim_A) + expand_B = np.broadcast_to(B, expand_dims + single_dim_B) + flat_A = expand_A.reshape((-1,) + single_dim_A) + flat_B = expand_B.reshape((-1,) + single_dim_B) + flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) + for a, b in zip(flat_A, flat_B)]) + return flat_X.reshape(expand_B.shape) + + def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) + x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), + upper, transpose, unitriangular)) + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] + + self.assertEqual(x, x_exp.to(device)) + + for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): + # test against scipy.linalg.solve_triangular + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_triangular_solve_singular(self, device, dtype): + b = torch.rand(3, 1, dtype=dtype, device=device) + A = torch.eye(3, 3, dtype=dtype, device=device) + A[-1, -1] = 0 # Now A is singular + err_str = r"triangular_solve_cpu: U\(3,3\) is zero, singular U\." + with self.assertRaisesRegex(RuntimeError, err_str): + torch.triangular_solve(b, A) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_triangular_solve_autograd(self, device, dtype): + def run_test(A_dims, B_dims): + A = torch.rand(*A_dims, dtype=dtype).requires_grad_() + b = torch.rand(*B_dims, dtype=dtype).requires_grad_() + + for upper, transpose, unitriangular in itertools.product((True, False), repeat=3): + def func(A, b): + return torch.triangular_solve(b, A, upper, transpose, unitriangular) + + gradcheck(func, [A, b]) + gradgradcheck(func, [A, b]) + + run_test((3, 3), (3, 4)) + run_test((3, 3), (3, 2)) + run_test((2, 3, 3), (2, 3, 4)) + run_test((2, 3, 3), (2, 3, 2)) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/test/test_torch.py b/test/test_torch.py index 143dadf09ffe..2e310f34be6c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10168,126 +10168,6 @@ def test_geqrf(self, device): self.assertEqual(b, b_placeholder) self.assertEqual(c, c_placeholder) - def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, - device, dtype): - triangle_function = torch.triu if upper else torch.tril - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = torch.randn(*A_dims, dtype=dtype, device=device) - A_triangular = triangle_function(A) - if unitriangular: - A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) - return b, A_triangular - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve(self, device, dtype): - for (k, n), (upper, unitriangular, transpose) in product(zip([2, 3, 5], [3, 5, 7]), - product([True, False], repeat=3)): - b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, - unitriangular, device, dtype) - x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] - if transpose: - self.assertLessEqual(b.dist(A.t().mm(x)), 4e-12) - else: - self.assertLessEqual(b.dist(A.mm(x)), 4e-12) - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double) - def test_triangular_solve_batched(self, device, dtype): - def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): - b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, - unitriangular, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, - unitriangular=unitriangular, - transpose=transpose)[0]) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.triangular_solve(b, A, upper=upper, - unitriangular=unitriangular, - transpose=transpose)[0] # Actual output - self.assertEqual(x_act, x_exp) # Equality check - if transpose: - self.assertLessEqual(b.dist(torch.matmul(A.transpose(-2, -1), x_act)), 1e-11) # Correctness check - else: - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-11) # Correctness check - - for (upper, unitriangular, transpose), batchsize in product(product([True, False], repeat=3), [1, 3, 4]): - triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), - upper, unitriangular, transpose) - - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve_batched_many_batches(self, device, dtype): - for upper, transpose, unitriangular in product([True, False], repeat=3): - b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), - upper, unitriangular, device, dtype) - x, _ = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - - b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), - upper, unitriangular, device, dtype) - x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, - unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - @dtypes(torch.double) - def test_triangular_solve_batched_broadcasting(self, device, dtype): - from scipy.linalg import solve_triangular as tri_solve - - def scipy_tri_solve_batched(A, B, upper, trans, diag): - batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] - single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] - expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), - torch.Size(batch_dims_B))) - expand_A = np.broadcast_to(A, expand_dims + single_dim_A) - expand_B = np.broadcast_to(B, expand_dims + single_dim_B) - flat_A = expand_A.reshape((-1,) + single_dim_A) - flat_B = expand_B.reshape((-1,) + single_dim_B) - flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) - for a, b in zip(flat_A, flat_B)]) - return flat_X.reshape(expand_B.shape) - - def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): - b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, - unitriangular, device, dtype) - x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), - upper, transpose, unitriangular)) - x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - - self.assertEqual(x, x_exp.to(device)) - - for upper, transpose, unitriangular in product([True, False], repeat=3): - # test against scipy.linalg.solve_triangular - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b - - @onlyCPU - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve_singular(self, device, dtype): - b = torch.rand(3, 1, device=device) - A = torch.eye(3, 3, device=device) - A[-1, -1] = 0 # Now A is singular - err_str = r"triangular_solve_cpu: U\(3,3\) is zero, singular U\." - with self.assertRaisesRegex(RuntimeError, err_str): - torch.triangular_solve(b, A) - @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.double) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index fe313a329a81..afc71cff8304 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8172,6 +8172,8 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `X` +Supports real-valued and complex-valued inputs. + Args: input (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where :math:`*` is zero of more batch dimensions (:math:`b`) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3426390256af..5d4b68178416 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1491,6 +1491,7 @@ def method_tests(): ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])]),), 'adv_index_var'), ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), + ('triangular_solve', (S, M), ((S, S), ), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('kron', (S, S), ((M, L),)) ] From e8a73fbf34b796f270f8d07ea0b5d8643beeb302 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 11 Nov 2020 16:32:09 -0800 Subject: [PATCH 35/93] Workaround PyTorch debug build crash using old GCC (#47805) Summary: gcc-7.4.x or older fails to compile XNNPACK in debug mode with internal compiler error Workaround this in a build script by pasing -O1 optimisation flag to XNNPACK if compiled on older compilers Fixes https://github.com/pytorch/pytorch/issues/47292 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47805 Reviewed By: seemethere Differential Revision: D24905758 Pulled By: malfet fbshipit-source-id: 93f4e3b3b5c10b69734627c50e36b2eb544699c8 --- cmake/Dependencies.cmake | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 0ce2d8b44a32..742c87b09233 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -510,6 +510,13 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK") set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON) + # Workaround for https://github.com/pytorch/pytorch/issues/47292 + if(CMAKE_BUILD_TYPE STREQUAL "Debug" AND CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5.0)) + # Compiling qu8-requantization/precise-psimd.c without any optimization flags on gcc-7.4 or older i + # Fails with internal compiler error + # Workaround by forcing -O1 for XNNPACK (i.e. build it with RelWithDebInfo) + set_property(TARGET XNNPACK APPEND_STRING PROPERTY COMPILE_FLAGS "-O1") + endif() endif() include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR}) From c9f6e70c09a51b643270983a8d259de3ad705700 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 11 Nov 2020 16:50:02 -0800 Subject: [PATCH 36/93] Refactor DDP uneven inputs control flags (#47394) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47394 This is a preliminary refactor for the next diff that will add an additional flag to control whether we throw a StopIteration or not. We basically move the flags for ddp uneven inputs to a simple class. ghstack-source-id: 116428177 Test Plan: CI Reviewed By: pritamdamania87 Differential Revision: D24739509 fbshipit-source-id: 96bf41bd1c02dd27e68f6f37d08e22f33129b319 --- torch/nn/parallel/distributed.py | 33 ++++++++++++------- .../_internal/distributed/distributed_test.py | 3 +- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 81cdd71e6732..1cb446f1155d 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -5,6 +5,7 @@ import inspect import logging import warnings +from typing import NamedTuple import torch @@ -90,6 +91,12 @@ def _dump_DDP_relevant_env_vars(): print(formatted_output) + +class _DDPUnevenInputsConfig(NamedTuple): + ddp_join_enabled: bool + ddp_join_divide_by_initial_world_size: bool + + class DistributedDataParallel(Module): r"""Implements distributed data parallelism that is based on ``torch.distributed`` package at the module level. @@ -151,13 +158,13 @@ class DistributedDataParallel(Module): .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the gradient will be ``M`` times smaller when compared to the same model - trained on a single node with ``batch=M*N`` if the loss is summed (NOT + trained on a single node with ``batch=M*N`` if the loss is summed (NOT averaged as usual) across instances in a batch (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent - training process compared to the local training counterpart. But in most - cases, you can just treat a DistributedDataParallel wrapped model, a - DataParallel wrapped model and an ordinary model on a single GPU as the + training process compared to the local training counterpart. But in most + cases, you can just treat a DistributedDataParallel wrapped model, a + DataParallel wrapped model and an ordinary model on a single GPU as the same (E.g. using the same learning rate for equivalent batch size). .. note:: @@ -391,7 +398,9 @@ def __init__(self, module, device_ids=None, self.find_unused_parameters = find_unused_parameters self.require_backward_grad_sync = True self.require_forward_param_sync = True - self.ddp_join_enabled = False + self.ddp_uneven_inputs_config = _DDPUnevenInputsConfig( + ddp_join_enabled=False, ddp_join_divide_by_initial_world_size=False + ) self.gradient_as_bucket_view = gradient_as_bucket_view if hasattr(module, '_ddp_params_and_buffers_to_ignore'): self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore @@ -644,13 +653,13 @@ def no_sync(self): self.require_backward_grad_sync = old_require_backward_grad_sync def forward(self, *inputs, **kwargs): - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: ones = torch.ones( 1, device=self.device ) work = dist.all_reduce(ones, group=self.process_group, async_op=True) self.reducer._set_forward_pass_work_handle( - work, self.ddp_join_divide_by_initial_world_size + work, self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size ) # Calling _rebuild_buckets before forward compuation, @@ -665,7 +674,7 @@ def forward(self, *inputs, **kwargs): if self.require_forward_param_sync: self._sync_params() - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: # Notify joined ranks whether they should sync in backwards pass or not. self._check_global_requires_backward_grad_sync(is_joined_rank=False) @@ -909,8 +918,10 @@ def join(self, divide_by_initial_world_size=True, enable=True): to spawn a single process that works on a single GPU.""" ) has_error = False - self.ddp_join_enabled = enable - self.ddp_join_divide_by_initial_world_size = divide_by_initial_world_size + self.ddp_uneven_inputs_config = _DDPUnevenInputsConfig( + ddp_join_enabled=enable, + ddp_join_divide_by_initial_world_size=divide_by_initial_world_size, + ) yield except Exception as e: # Set to skip any processing in the finally block. @@ -1163,7 +1174,7 @@ def _sync_params(self): # If we are running DDP with the join manager, we have to agree # upon a rank to sync module buffers from, since rank 0 may # already have been joined and have stale module buffers. - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: authoritative_rank = self._find_common_rank(dist.get_rank(), True) else: # The process with rank 0 is considered the authoritative copy. diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 21fd1a2a88e4..fd14039e1859 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -3736,7 +3736,8 @@ def test_ddp_uneven_input_join_disable(self): net.module.weight.grad.item(), expected_grad ) - self.assertFalse(net.ddp_join_enabled) + join_config = net.ddp_uneven_inputs_config + self.assertFalse(join_config.ddp_join_enabled) self.validate_net_equivalence(net) @require_backend({"gloo", "nccl"}) From c5834b6a23ed6767b0c52641be54efa6f94be48b Mon Sep 17 00:00:00 2001 From: Mehdi Mirzazadeh Date: Wed, 11 Nov 2020 19:05:46 -0800 Subject: [PATCH 37/93] Look in named-buffers of module for tensors (#47641) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47641 ghstack-source-id: 116450114 Test Plan: Presubmit tests Reviewed By: jamesr66a Differential Revision: D24848318 fbshipit-source-id: f6ede3def9d6f1357c4fd3406f97721dea06b9f1 --- torch/fx/graph_module.py | 9 ++++++++- torch/fx/symbolic_trace.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 3525e180c43f..c593734eea4c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -82,7 +82,14 @@ def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: setattr(to_module, item, t) from_module, to_module = f, t - setattr(to_module, field, getattr(from_module, field)) + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index a75d5ff908f8..13ffc2cb0100 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -75,6 +75,10 @@ def create_arg(self, a: Any) -> Argument: if a is p: return self.create_node('get_attr', n, (), {}) raise NameError('parameter is not a member of this module') + elif isinstance(a, torch.Tensor): + for n, p in self.root.named_buffers(): + if a is p: + return self.create_node('get_attr', n, (), {}) # Tensors do not have a reliable string repr() from which they can be # constructed (and we probably don't want to rely on that, either), so # for any constant Tensor values we encounter, first search for if they From b46787d6d719e476eab4ff739b6fbf8b35333db7 Mon Sep 17 00:00:00 2001 From: Wang Xu Date: Wed, 11 Nov 2020 19:27:28 -0800 Subject: [PATCH 38/93] add cost_aware_partition (#47673) Summary: [WIP]This PR adds cost_aware_partition method in Partitioner class. The method partitions the fx graph module based on the latency of the whole graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47673 Reviewed By: gcatron Differential Revision: D24896685 Pulled By: scottxu0730 fbshipit-source-id: 1b1651fe82ce56554f99d68da116e585c74099ed --- test/test_fx_experimental.py | 61 +++++- torch/fx/experimental/Partitioner.py | 205 ++++++++++++--------- torch/fx/experimental/partitioner_utils.py | 91 ++++++++- 3 files changed, 267 insertions(+), 90 deletions(-) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 9a75663e4205..884e21a17ddd 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -4,7 +4,7 @@ from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.experimental import GraphManipulation -from torch.fx.experimental.Partitioner import Partitioner, Device, PartitionerConfig +from torch.fx.experimental.Partitioner import Partitioner from torch.fx.experimental.rewriter import RewritingTracer from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase @@ -12,6 +12,8 @@ NodeLatency, get_partition_to_latency_mapping, get_latency_of_partitioned_graph, + Device, + PartitionerConfig ) from typing import Union, Callable @@ -306,6 +308,63 @@ def get_node_to_latency_mapping(fx_module: GraphModule): ) assert critical_path_latency_sec == 208.0 + def test_cost_aware_partition(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + add_1 = a + torch.rand(4) + add_2 = add_1 + torch.rand(4) + linear_1 = self.linear(add_1) + add_3 = add_2 + torch.rand(4) + add_4 = add_2 + linear_1 + add_5 = add_3 + add_4 + return add_5 + + def get_node_to_latency_mapping(fx_module: GraphModule): + node_to_latency_mapping: Dict[Node, Nodelatency] = {} + for node in fx_module.graph.nodes: + if node.op not in {'output', 'placeholder', 'get_attr'}: + if node.size_bytes.total_size == node.size_bytes.output_size: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, 1) + else: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, node.size_bytes.output_size) + return node_to_latency_mapping + + m = MyModule() + traced = symbolic_trace(m) + a = torch.rand(4) + GraphManipulation.get_size_of_all_nodes(traced, [a]) + devices = [ + Device('dev_0', 125, 0), + Device('dev_1', 125, 1), + Device('dev_2', 125, 2), + Device('dev_3', 125, 3) + ] + node_to_latency_mapping = get_node_to_latency_mapping(traced) + partitioner_config = PartitionerConfig( + devices, + is_sparse_nn=False, + is_cost_aware=True, + transfer_rate_bytes_per_sec=0.5, + node_to_latency_mapping=node_to_latency_mapping + ) + partitioner = Partitioner() + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a), module_with_submodules(a)) + partitions = partitioner.partitions + partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping) + critical_path_latency_sec = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + partitioner_config.transfer_rate_bytes_per_sec + ) + assert critical_path_latency_sec == 160. + def test_call_to_assert_no_msg(self): class M(torch.nn.Module): def forward(self, a, b): diff --git a/torch/fx/experimental/Partitioner.py b/torch/fx/experimental/Partitioner.py index 9a2c7e3e1881..a11e07cb42c8 100644 --- a/torch/fx/experimental/Partitioner.py +++ b/torch/fx/experimental/Partitioner.py @@ -4,6 +4,9 @@ import torch from torch.fx.experimental.subgraph_creation_example import split_module import operator +from torch.fx.experimental.partitioner_utils import Partition, \ + Device, PartitionerConfig, get_partition_to_latency_mapping,\ + get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of class DAGNode(): """ @@ -43,96 +46,14 @@ def create_node( node = DAGNode(submodule_node, input_nodes, output_nodes, logical_devices, size_bytes) self.nodes.append(node) -class Partition: - """Partition class contains all the information about an individual partition. - It also provides necessary methods for manipulation the partition. - """ - def __init__(self, partition_id: int) -> None: - self.nodes: Set[Node] = set() - self.partition_id = partition_id - self.parents: Set['Partition'] = set() - self.children: Set['Partition'] = set() - self.bfs_level: int = -1 - self.used_mem_bytes: int = 0 - self.logical_device_ids: List[int] = [] - - def __str__(self): - return str(self.partition_id) - - def recalculate_mem_size(self): - self.used_mem_bytes = 0 - for node in self.nodes: - self.used_mem_bytes += get_extra_size_of(node, self.nodes) - - def add_node(self, node): - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Add current node's input nodes if they are placeholder or constants - for n in input_nodes: - if n.op in {'placeholder', 'get_attr'}: - self.nodes.add(n) - self.nodes.add(node) - - def remove_node(self, node): - # Remove a node only if the node is in the partition - if node in self.nodes: - self.nodes.remove(node) - # Collect the node's input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Check if an input node is a placeholder or get_attr, - # and this input node is not used by some other nodes in this partition, - # the remove this input node - for input_node in input_nodes: - if all([n not in self.nodes for n in input_node.users]): - self.nodes.remove(input_node) - class PartitionResult(NamedTuple): """NameTuple used for returning DAG and a new graph module """ dag: DAG module_with_submodules: GraphModule -class Device(NamedTuple): - name: str - available_mem_bytes: int - logical_id: int - -class PartitionerConfig(NamedTuple): - devices: List[Device] - is_sparse_nn: bool = False - """Followings are some helper functions for partition manipulation""" -def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: - """Given a node and a set of nodes, - this function return the extra size that needed - if this node is included in this set. - """ - # Find all its input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Calculate total size of related nodes - total_size_of_input_nodes = 0 - for n in input_nodes: - # Make sure this node hasn't been in this set yet - if n not in nodes: - size_bytes = getattr(n, 'size_bytes', None) - if size_bytes: - total_size_of_input_nodes += size_bytes.output_size - else: - raise RuntimeError('node has no size_bytes attr') - # Don't forget the op node itself - size_bytes = getattr(node, 'size_bytes', None) - if size_bytes: - total_size_of_input_nodes += size_bytes.total_size - else: - raise RuntimeError('node has no size_bytes attr') - return total_size_of_input_nodes - def combine_two_partitions( partition_0: Partition, partition_1: Partition, @@ -148,10 +69,7 @@ def combine_two_partitions( partitions.append(partition) partitions.remove(partition_0) partitions.remove(partition_1) - # reset parents and children for all partitions - for partition in partitions: - partition.parents = set() - partition.children = set() + # Reorganize partitions reorganize_partitions(partitions) return @@ -342,6 +260,11 @@ def partition_graph( # sparse_nn_partition only support same memory size # TODO: add different size support for sparse_nn_partition self.sparse_nn_partition(available_mem_bytes) + elif partitioner_config.is_cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping + ) else: self.size_based_partition(available_mem_bytes) module_with_submodules = self.do_partition() @@ -499,7 +422,7 @@ def sparse_nn_partition(self, available_mem_bytes: int) -> None: It first traverse all the nodes and do the partitions based on memory size. If the current partition has no enough memory left for a new op node (call_module, call_method, call_function), a new partition is created. - Different for size_based_partition, when traversing cross the boundary between + Different from size_based_partition, when traversing cross the boundary between non-embedding nodes and embedding nodes, a new partition is created regardlessly. For example, if the current node is a non-embedding node but the next node is an embedding node, a new partition is created for the next node. @@ -649,3 +572,111 @@ def is_embedding_node(node: Node) -> bool: # Get the node to partition mapping self.node_to_partition = get_node_to_partition_mapping(self.partitions) return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency] + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole graph. + In partitioner_utils.py, the cost model is built. + The algorithm is: + #1. At every begining, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + + def try_combining_partitions( + p0_index, + p1_index, + partitions + ) -> float: + """Given two partitions and a list of partitions, try to combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if (abs(p0.bfs_level - p1.bfs_level) <= 1) or (p0 in p1.parents) or p0 in (p1.children): + combine_two_partitions(p0, p1, partitions) + # Check if the modified partition list can be mapped to devices after combination + found_deivce = get_device_to_partitions_mapping(partitions, self.devices) + if not found_deivce: + return float('inf') + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping) + cost = get_latency_of_partitioned_graph(partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec) + return cost + # If two partition can not be combined, the cost is inf + return float('inf') + + def search_combination( + transfer_rate_bytes_per_sec, + node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Going through all the partition pairs and see + if the pair of partitions can be combined. + 2. If they are combined, the cost is calculated. + 3. Select the minimum cost and combine its cooresponding partition pair + """ + partition_to_latency_mapping = get_partition_to_latency_mapping(self.partitions, node_to_latency_mapping) + cost = get_latency_of_partitioned_graph(self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec) + if len(self.partitions) == 1: + return False + partition_pair: List[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions( + i, + j, + self.partitions[:] + ) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return True + return False + + for node in self.graph_module.graph.nodes: + if node.op not in {'placeholder', 'get_attr', 'output'}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, + node_to_latency_mapping + ) + # Make sure all partitions are set up correctly. + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index 23a991b1e44f..0f488e02e912 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,6 +1,55 @@ -from typing import NamedTuple, Dict, List +from typing import NamedTuple, Dict, List, Set from torch.fx.node import Node, map_arg -from torch.fx.experimental.Partitioner import Partition +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + def __init__(self, partition_id: int) -> None: + self.nodes: Set[Node] = set() + self.partition_id = partition_id + self.parents: Set['Partition'] = set() + self.children: Set['Partition'] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: List[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {'placeholder', 'get_attr'}: + self.nodes.add(n) + self.nodes.add(node) + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all([n not in self.nodes for n in input_node.users]): + self.nodes.remove(input_node) + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int class NodeLatency(NamedTuple): # Latency due to the memory bandwidth @@ -16,6 +65,40 @@ class PartitionLatency(NamedTuple): # Latency of the critical path overall_latency_sec: float +class PartitionerConfig(NamedTuple): + devices: List[Device] + is_sparse_nn: bool = False + is_cost_aware: bool = False + transfer_rate_bytes_per_sec: float = 0. + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + +def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError('node has no size_bytes attr') + # Don't forget the op node itself + size_bytes = getattr(node, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError('node has no size_bytes attr') + return total_size_of_input_nodes + def get_latency_of_one_partition( partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] @@ -92,6 +175,10 @@ def get_comm_latency_between(parent_partition: Partition, child_partition: Parti """Given two partitions (parent and child), calculate the communication latency between the two. """ + # If two partitions are on the same device, the comm latency is 0. + if parent_partition.logical_device_ids != [] and child_partition.logical_device_ids != [] \ + and parent_partition.logical_device_ids == child_partition.logical_device_ids: + return 0. # Keep tracking the communication size between parent and child comm_size = 0 # Keep tracking all the counted node From dd77d5a1d4cf175d46de4c3630d8c1ac36fdac67 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Nov 2020 20:56:56 -0800 Subject: [PATCH 39/93] [quant][refactor] factor out get_combined_dict function (#47781) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47781 Test Plan: Imported from OSS Reviewed By: supriyar Differential Revision: D24900303 fbshipit-source-id: 1a2cb0ec536384abcd140e0d073f0965ed2800cd --- torch/quantization/fuser_method_mappings.py | 6 +++--- torch/quantization/fx/fuse.py | 8 +++++--- torch/quantization/fx/quantize.py | 13 +++++++------ torch/quantization/quantization_mappings.py | 6 ++---- torch/quantization/utils.py | 9 +++++++++ 5 files changed, 26 insertions(+), 16 deletions(-) create mode 100644 torch/quantization/utils.py diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 0b72f5485231..a20d9c6ad682 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -3,6 +3,8 @@ from typing import Union, Callable, Tuple, Dict, Optional, Type +from .utils import get_combined_dict + def fuse_conv_bn(conv, bn): r"""Given the conv and bn modules, fuses them and returns the fused module @@ -101,9 +103,7 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None): ''' if additional_fuser_method_mapping is None: additional_fuser_method_mapping = {} - all_mappings = DEFAULT_OP_LIST_TO_FUSER_METHOD.copy() - for k, v in additional_fuser_method_mapping: - all_mappings[k] = v + all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping) fuser_method = all_mappings.get(op_list, None) assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list) return fuser_method diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 56b375a02c00..5477dc39999b 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -5,6 +5,10 @@ from torch.fx.graph import Graph +from ..utils import ( + get_combined_dict +) + from .pattern_utils import ( is_match, get_default_fusion_patterns, @@ -22,9 +26,7 @@ def fuse(self, model, fuse_custom_config_dict=None): self.modules = dict(input_root.named_modules()) additional_fusion_patterns = fuse_custom_config_dict.get("additional_quant_pattern", {}) - fusion_patterns = get_default_fusion_patterns().copy() - for k, v in additional_fusion_patterns.items(): - fusion_patterns[k] = v + fusion_patterns = get_combined_dict(get_default_fusion_patterns(), additional_fusion_patterns) # find fusion fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns) self.fused_graph = Graph() diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index b87612c97dbe..9964fde074f9 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -24,6 +24,10 @@ is_activation_post_process ) +from ..utils import ( + get_combined_dict +) + from .pattern_utils import ( is_match, get_default_quant_patterns, @@ -233,9 +237,7 @@ def __init__(self): def _qat_swap_modules(self, root, additional_qat_module_mapping): - all_mappings = get_default_qat_module_mappings().copy() - for k, v in additional_qat_module_mapping.items(): - all_mappings[k] = v + all_mappings = get_combined_dict(get_default_qat_module_mappings(), additional_qat_module_mapping) convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False) def _generate_qconfig_map(self, @@ -327,10 +329,9 @@ def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalon """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} + additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {}) - self.patterns = get_default_quant_patterns().copy() - for k, v in additional_quant_patterns.items(): - self.patterns[k] = v + self.patterns = get_combined_dict(get_default_quant_patterns(), additional_quant_patterns) flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict) # TODO: support regex as well diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index c163e44d3414..0aa5cb845e12 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -14,6 +14,7 @@ default_affine_fixed_qparams_fake_quant, default_symmetric_fixed_qparams_fake_quant, ) +from .utils import get_combined_dict # Default map for swapping float module to quantized ones DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = { @@ -117,10 +118,7 @@ def get_static_quant_module_class(float_module_class, additional_static_quant_ma """ if additional_static_quant_mapping is None: additional_static_quant_mapping = {} - all_mappings = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.copy() - for k, v in additional_static_quant_mapping.items(): - all_mappings[k] = v - + all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping) static_quant_module_class = all_mappings.get(float_module_class, None) assert static_quant_module_class is not None, \ "Floating point module class {}".format(str(float_module_class)) + \ diff --git a/torch/quantization/utils.py b/torch/quantization/utils.py new file mode 100644 index 000000000000..956bf7490eaf --- /dev/null +++ b/torch/quantization/utils.py @@ -0,0 +1,9 @@ +""" +Utils shared by different modes of quantization (eager/graph) +""" + +def get_combined_dict(default_dict, additional_dict): + d = default_dict.copy() + for k, v in additional_dict.items(): + d[k] = v + return d From 47386722da66d5d7f12322c500474794756ea8d7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Nov 2020 21:28:53 -0800 Subject: [PATCH 40/93] [quant][graphmode][fx][refactor] insert_observer (#47782) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47782 Test Plan: python test/test_quantization.py TestQuantizeFx Imported from OSS Reviewed By: supriyar Differential Revision: D24900305 fbshipit-source-id: b00a90ab85badea7d18ae007cc68d0bcd58ab15c --- test/quantization/test_quantize_fx.py | 1 - torch/quantization/fx/quantize.py | 58 ++++++++++++--------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 16694b0f0356..af9827d00307 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -1028,7 +1028,6 @@ def forward(self, x): return x m = M() - print(m.__dict__.keys()) m.eval() qconfig_dict = {'': torch.quantization.default_qconfig} prepared = prepare_fx(m, qconfig_dict) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 9964fde074f9..8c131de22fd1 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -374,6 +374,28 @@ def load_arg(a): graph_inputs.append(node.name) get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_') + model_device = assert_and_get_unique_device(model) + + def insert_observer(node, observer): + """Insert observer for node by modifying the observed_graph and + attach observer module to the model + Args: + node: Node + observer: observer/fake_quantize module instance + """ + # respect device affinity when adding observers + if model_device: + observer.to(model_device) + # add observer module as attribute + prefix = node.name + '_activation_post_process_' + get_new_observer_name = get_new_attr_name_with_prefix(prefix) + observer_name = get_new_observer_name(model) + setattr(model, observer_name, observer) + # put observer instance activation_post_process map + self.activation_post_process_map[node.name] = observer + # insert observer call + env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) + observed_node_names_set.add(node.name) result_node : Optional[Node] = None for node in model.graph.nodes: @@ -384,23 +406,11 @@ def load_arg(a): if node.name in observed_node_names_set: continue - prefix = node.name + '_activation_post_process_' root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None)) if root_node is None: env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: env[node.name] = observed_graph.node_copy(node, load_arg) - - def insert_observer(node, observer, device): - get_new_observer_name = get_new_attr_name_with_prefix(prefix) - observer_name = get_new_observer_name(model) - setattr(model, observer_name, observer) - self.activation_post_process_map[node.name] = observer - env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) - observed_node_names_set.add(node.name) - if device: - getattr(model, observer_name).to(device) - # index for input of custom module that needs to be observed in parent standalone_module_input_idxs = None if qconfig is not None: @@ -437,8 +447,7 @@ def insert_observer(node, observer, device): assert activation_post_process_ctr is not None, \ "activation_post_process constructor not provided for " + \ "pattern:" + str(pattern) - device = assert_and_get_unique_device(model) - insert_observer(node, activation_post_process_ctr(), device) + insert_observer(node, activation_post_process_ctr()) elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and not model.training) or isinstance(obj, CopyNode): # inserting observers for output of observed module, or mark the output @@ -476,17 +485,14 @@ def input_is_observed(arg): elif obj.all_node_args: # observer for outputs new_observer = qconfig.activation() - # respect device affinity when adding observers - device = assert_and_get_unique_device(model) - insert_observer(node, new_observer, device) + insert_observer(node, new_observer) # insert observer for input of standalone module if standalone_module_input_idxs is not None: for idx in standalone_module_input_idxs: if node.args[idx].name not in observed_node_names_set: new_observer = qconfig.activation() - device = assert_and_get_unique_device(model) - insert_observer(node.args[idx], new_observer, device) + insert_observer(node.args[idx], new_observer) else: env[node.name] = observed_graph.node_copy(node, load_arg) @@ -497,21 +503,9 @@ def input_is_observed(arg): # in parent graph standalone_module_observed_input_idxs.append(graph_inputs.index(node.name)) continue - get_new_observer_name = get_new_attr_name_with_prefix(prefix) - observer_name = get_new_observer_name(model) _, activation_post_process_ctr = quants[node.name] if activation_post_process_ctr is not None: - # TODO: use insert_observer - new_observer = activation_post_process_ctr() - - # respect device affinity when adding observers - device = assert_and_get_unique_device(model) - if device: - new_observer.to(device) - self.activation_post_process_map[node.name] = new_observer - setattr(model, observer_name, self.activation_post_process_map[node.name]) - env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) - observed_node_names_set.add(node.name) + insert_observer(node, activation_post_process_ctr()) model = GraphModule(model, observed_graph) self.save_state(model) From 89b371bc281c7b42a775f1ff3eb981abf28fe43e Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Wed, 11 Nov 2020 22:42:09 -0800 Subject: [PATCH 41/93] [quant] Add support for 2D indices for quantized embedding operators (#47766) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47766 The operator now supports accepting 2D indices as inputs. For embedding operators, we set the default offsets in the op since the FBGEMM kernel expects it to be set Output shape depends on the shape if the indices. For embedding_bag operator, if indices is 2D (B, N) then offsets should be set to None by user. In this case the input is interpreted as B bags each of fixed length N. Output shape is still 2-D in this case. Test Plan: python test/test_quantization.py TestQuantizedEmbeddingOps.test_embedding_bag_2d_indices python test/test_quantization.py TestQuantizedEmbeddingOps.test_embedding_2d_indices Imported from OSS Reviewed By: jerryzh168 Differential Revision: D24895048 fbshipit-source-id: 2020910e1d85ed8673eedee2e504611ba260d801 --- .../quantized/cpu/embedding_packed_params.h | 3 +- .../ATen/native/quantized/cpu/fbgemm_utils.h | 3 +- .../native/quantized/cpu/qembeddingbag.cpp | 99 ++++++++++++++----- test/quantization/test_quantized_op.py | 63 ++++++++++++ 4 files changed, 143 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h index cf98bc56bb82..d2e7500bf302 100644 --- a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h +++ b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h @@ -10,7 +10,8 @@ struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, - bool include_last_offset) = 0; + bool include_last_offset, + bool is_embedding_op) = 0; virtual at::Tensor embeddingbag_4bit( const at::Tensor& indices, diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 765e93beca36..a2349790d117 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -345,7 +345,8 @@ struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, - bool include_last_offset) override; + bool include_last_offset, + bool is_embedding_op) override; at::Tensor embeddingbag_4bit( const at::Tensor& indices, diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index 8aa16fb0f6cf..13f98cd3a494 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -22,7 +22,6 @@ at::Tensor embedding_bag_4bit_impl( const c10::optional& compressed_indices_mapping, bool include_last_offset) { TORCH_CHECK(weight.dim() == 2); - TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(offsets.dim() == 1); const auto weight_data = weight.data_ptr(); @@ -198,12 +197,11 @@ at::Tensor embedding_bag_byte_impl( bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, - bool include_last_offset) { + bool include_last_offset, + bool is_embedding_op) { TORCH_CHECK(weight.scalar_type() == at::kByte); TORCH_CHECK(weight.dim() == 2); - TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(offsets.dim() == 1); - const auto weight_data = weight.data_ptr(); const auto indices_data = indices.data_ptr(); auto offsets_data = offsets.data_ptr(); @@ -223,6 +221,7 @@ at::Tensor embedding_bag_byte_impl( int64_t output_size = M - 1; std::vector offsets_include_last_val; + if (!include_last_offset) { output_size = M; offsets_include_last_val.resize(M + 1); @@ -237,8 +236,12 @@ at::Tensor embedding_bag_byte_impl( offsets_include_last_val[M] = indices.numel(); offsets_data = offsets_include_last_val.data(); } - - std::vector shape = {output_size, D}; + std::vector shape; + if (indices.dim() == 2 && is_embedding_op) { + shape = {indices.size(0), indices.size(1), D}; + } else { + shape = {output_size, D}; + } auto output = at::empty(shape, weight.options().dtype(at::kFloat)); auto* output_data = output.data_ptr(); @@ -314,11 +317,28 @@ at::Tensor embedding_bag_byte_helper( bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, - bool include_last_offset) { + bool include_last_offset, + bool is_embedding_op) { + at::Tensor offsets; TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_4bit_rowwise_offsets expects offsets to be set"); - auto offsets = offsets_in.value(); + indices.dim() == 1 || indices.dim() == 2, + "qembedding/qembedding_bag operator supports 1 or 2d indices, got ", + indices.dim()); + // For embedding_bag operator with 2D indices, we set the offsets explicitly + // here. + if (indices.dim() == 2 && !is_embedding_op) { + TORCH_CHECK( + !offsets_in.has_value(), + "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); + + offsets = + at::arange(0, indices.numel(), indices.size(1), indices.scalar_type()); + } else { + TORCH_CHECK( + offsets_in.has_value(), + "embedding_bag_byte expects offsets to be set for 1D indices."); + offsets = offsets_in.value(); + } TORCH_CHECK( indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong, @@ -341,7 +361,8 @@ at::Tensor embedding_bag_byte_helper( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + is_embedding_op); } else if ( indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) { return embedding_bag_byte_impl( @@ -351,7 +372,8 @@ at::Tensor embedding_bag_byte_helper( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + is_embedding_op); } else if ( indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) { return embedding_bag_byte_impl( @@ -361,7 +383,8 @@ at::Tensor embedding_bag_byte_helper( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + is_embedding_op); } // default case given the TORCH_CHECK above @@ -372,7 +395,8 @@ at::Tensor embedding_bag_byte_helper( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + is_embedding_op); } at::Tensor embedding_bag_4bit_helper( @@ -383,10 +407,27 @@ at::Tensor embedding_bag_4bit_helper( const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { + at::Tensor offsets; TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_4bit_rowwise_offsets expects offsets to be set"); - auto offsets = offsets_in.value(); + indices.dim() == 1 || indices.dim() == 2, + "qembedding/qembedding_bag operator supports 1 or 2d indices, got ", + indices.dim()); + + // For embedding_bag operator with 2D indices, we need to set the offsets + // explicitly here. + if (indices.dim() == 2) { + TORCH_CHECK( + !offsets_in.has_value(), + "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); + + offsets = + at::arange(0, indices.numel(), indices.size(1), indices.scalar_type()); + } else { + TORCH_CHECK( + offsets_in.has_value(), + "embedding_bag_4bit operator expects offsets to be set for 1D indices."); + offsets = offsets_in.value(); + } TORCH_CHECK( indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong, @@ -448,7 +489,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, - bool include_last_offset) { + bool include_last_offset, + bool is_embedding_op) { return embedding_bag_byte_helper( packed_w.contiguous(), indices, @@ -456,7 +498,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + is_embedding_op); } at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( @@ -497,7 +540,8 @@ Tensor embedding_bag_byte_rowwise_offsets( pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + false /* is_embedding_op */); } Tensor embedding_bag_4bit_rowwise_offsets( @@ -540,7 +584,8 @@ class QEmbeddingBag final { pruned_weights, per_sample_weights_, compressed_indices_mapping, - include_last_offset); + include_last_offset, + false /* is_embedding_op */); } else if (bit_rate == 4) { return packed_weight->embeddingbag_4bit( indices, @@ -563,12 +608,20 @@ class QEmbedding final { const c10::intrusive_ptr& packed_weight, const Tensor& indices, bool pruned_weights) { + // Set default offsets here since the FBGEMM lookup op expects it. const auto offsets_size = indices.numel(); - at::Tensor offsets = at::arange(0, offsets_size, at::kLong); + at::Tensor offsets = at::arange(0, offsets_size, indices.scalar_type()); at::Tensor output; if (bit_rate == 8) { return packed_weight->embeddingbag_byte( - indices, offsets, pruned_weights, c10::nullopt, c10::nullopt, false); + indices, + offsets, + pruned_weights, + c10::nullopt, + c10::nullopt, + false /* include_last_offset */, + true /* is_embedding_op */); + } else { TORCH_INTERNAL_ASSERT( "Currently only support 8-bit embedding quantization"); diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index fcc4e689e5e5..ee6a757a5c9e 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -2964,6 +2964,7 @@ def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimize self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams) + def embedding_bag_rowwise_offsets_run( self, bit_rate, num_embeddings, embedding_dim, num_offsets, @@ -3161,6 +3162,68 @@ def test_embedding_byte(self, num_embeddings, embedding_dim): ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3) + + @skipIfNoFBGEMM + def test_embedding_2d_indices(self): + """ + Tests the case where 2D indices are passed into the operator + In this case the operator computes the correct offsets argument. + Output shape is dependent on the indices dimension. + """ + quant_op = torch.ops.quantized.embedding_byte + prepack_op = torch.ops.quantized.embedding_bag_prepack + + indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) + weights = torch.randn(10, 12, dtype=torch.float32) + + ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) + obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + qparams = obs.calculate_qparams() + + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + packed_weight = prepack_op(qweight) + qresult = quant_op(packed_weight, indices, pruned_weights=False) + torch.testing.assert_allclose(ref, qresult, atol=0.05, rtol=1e-3) + + @skipIfNoFBGEMM + def test_embedding_bag_2d_indices(self): + """ + Tests the case where 2D indices are passed into the operator + In this case the operator computes the correct offsets argument. + """ + indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) + weights = torch.randn(10, 12, dtype=torch.float32) + + embedding_bag = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=False, _weight=weights, + scale_grad_by_freq=False, mode='sum' + ) + result = embedding_bag(indices) + + pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets + pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack + q_weights = pt_prepack_op(weights) + qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False) + torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + + # Test TorchBind based embedding_bag operator + obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + # Get the scale and zero point for the weight tensor + qparams = obs.calculate_qparams() + + # Quantize the weights to 8bits + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + + packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) + qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0) + + torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + + class TestQuantizedConv(TestCase): def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs, strides, i_pads, o_pads, channelwise): From 70ae5685f96149c84ff8280ca93720bcf8085f83 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 22:49:06 -0800 Subject: [PATCH 42/93] [reland][c10d] switch ProcessGroup::Work to be managed by intrusive_ptr (#47806) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47806 reland https://github.com/pytorch/pytorch/pull/44046 Test Plan: wait for ci Reviewed By: gmagogsfm Differential Revision: D24905245 fbshipit-source-id: ad75ace5432fcfd22d513878f5a73c4bb017324e --- test/cpp_extensions/cpp_c10d_extension.cpp | 32 +++--- test/cpp_extensions/cpp_c10d_extension.hpp | 26 ++--- torch/csrc/distributed/c10d/init.cpp | 5 +- .../distributed/rpc/process_group_agent.cpp | 2 +- .../distributed/rpc/process_group_agent.h | 4 +- torch/lib/c10d/ProcessGroup.cpp | 2 +- torch/lib/c10d/ProcessGroup.hpp | 35 +++--- torch/lib/c10d/ProcessGroupGloo.cpp | 101 +++++++++--------- torch/lib/c10d/ProcessGroupGloo.hpp | 38 +++---- torch/lib/c10d/ProcessGroupMPI.cpp | 42 ++++---- torch/lib/c10d/ProcessGroupMPI.hpp | 36 +++---- torch/lib/c10d/ProcessGroupNCCL.cpp | 52 ++++----- torch/lib/c10d/ProcessGroupNCCL.hpp | 46 ++++---- torch/lib/c10d/ProcessGroupRoundRobin.cpp | 30 +++--- torch/lib/c10d/ProcessGroupRoundRobin.hpp | 30 +++--- torch/lib/c10d/comm.cpp | 4 +- torch/lib/c10d/example/allreduce.cpp | 2 +- torch/lib/c10d/reducer.cpp | 2 +- torch/lib/c10d/reducer.hpp | 9 +- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 10 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 16 +-- torch/lib/c10d/test/ProcessGroupMPITest.cpp | 38 +++---- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 12 +-- 24 files changed, 295 insertions(+), 287 deletions(-) diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index b4901cdbcf4d..50e5f5861caa 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -23,85 +23,85 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size) ProcessGroupTest::~ProcessGroupTest() {} -std::shared_ptr ProcessGroupTest::broadcast( +c10::intrusive_ptr ProcessGroupTest::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce( +c10::intrusive_ptr ProcessGroupTest::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupTest::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); } -std::shared_ptr ProcessGroupTest::reduce( +c10::intrusive_ptr ProcessGroupTest::reduce( std::vector& tensors, const ReduceOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce"); } -std::shared_ptr ProcessGroupTest::allgather( +c10::intrusive_ptr ProcessGroupTest::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather"); } -std::shared_ptr ProcessGroupTest::allgather_base( +c10::intrusive_ptr ProcessGroupTest::allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather_base"); } -std::shared_ptr ProcessGroupTest::barrier( +c10::intrusive_ptr ProcessGroupTest::barrier( const BarrierOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::gather( +c10::intrusive_ptr ProcessGroupTest::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support gather"); } -std::shared_ptr ProcessGroupTest::scatter( +c10::intrusive_ptr ProcessGroupTest::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support scatter"); } -std::shared_ptr ProcessGroupTest::reduce_scatter( +c10::intrusive_ptr ProcessGroupTest::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); } -std::shared_ptr ProcessGroupTest::send( +c10::intrusive_ptr ProcessGroupTest::send( std::vector& tensors, int dstRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support send"); } -std::shared_ptr ProcessGroupTest::recv( +c10::intrusive_ptr ProcessGroupTest::recv( std::vector& tensors, int srcRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support recv"); } -std::shared_ptr ProcessGroupTest::recvAnysource( +c10::intrusive_ptr ProcessGroupTest::recvAnysource( std::vector& tensor, int tag) { throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index d8dffcd20327..8aeec736d440 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -41,61 +41,61 @@ class ProcessGroupTest : public ProcessGroup { explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index d9ddf35ee1df..136efd32fc87 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,5 +1,6 @@ #include +#include #include #ifndef _WIN32 #include @@ -59,6 +60,8 @@ constexpr auto kDeprecationWarning = "{} API is being deprecated, please ping " "https://github.com/pytorch/pytorch/issues/46291 " "if you see this warning"; +template +using intrusive_ptr_class_ = py::class_>; // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. @@ -1045,7 +1048,7 @@ that adds a prefix to each key inserted to the store. py::call_guard()); #endif - shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") + intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) .def( "is_success", diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 2f29adc8f0c4..13e685b8fe74 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -398,7 +398,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) { // ProcessGroup is not thread-safe when sending with the same tag, // hence the lock - std::vector> pendingSends; + std::vector> pendingSends; const auto dst = work.to_.id_; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 1bc8db9ebf20..70fb1b40244d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // Lock and shared ptr to currently pending work, set in listenloop() and // interruptible in shutdown(). std::mutex recvWorkMutex_; - std::shared_ptr recvWork_; + c10::intrusive_ptr recvWork_; // Map of dst rank to current oustanding sends that we are waiting on. In the // case of a call to ::shutdown() while we are still waiting on these sends, // the pending sends contained in this map will be aborted, allowing the // waiting thread to be unblocked. std::unordered_map< worker_id_t, - std::set>> + std::set>> currentPendingSends_; // Lock to serialize access to the above map. std::mutex pendingSendMutex_; diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 3521ed42c840..1d0d451f21a9 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -164,7 +164,7 @@ ProcessGroup::~ProcessGroup() {} // This is introduced so that implementors of ProcessGroup would not need to // have this implmentation. -std::shared_ptr ProcessGroup::allgather_coalesced( +c10::intrusive_ptr ProcessGroup::allgather_coalesced( std::vector>& /* usused */, std::vector& /* usused */, const AllgatherOptions& /* usused */) { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 5e90dccc25c0..63996b516a06 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -70,12 +70,11 @@ bool isP2POp(OpType opType); // class ProcessGroup { public: - // Please do not use ProcessGroup::Work API, it is going away, to be // replaced by ivalue::Future. // Python binding for this class might change, please do not assume // this will be bound using pybind. - class Work { + class Work : public torch::CustomClassHolder { public: Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); @@ -171,25 +170,25 @@ class ProcessGroup { return size_; } - virtual std::shared_ptr broadcast( + virtual c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) = 0; - virtual std::shared_ptr allreduce( + virtual c10::intrusive_ptr allreduce( std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; // This will be moved out of ProcessGroup, do not add dependencies on this // function. - virtual std::shared_ptr allreduce_coalesced( + virtual c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; - virtual std::shared_ptr reduce( + virtual c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) = 0; - virtual std::shared_ptr allgather( + virtual c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -197,7 +196,7 @@ class ProcessGroup { // Gathers a single tensor inputBuffer into a single buffer outputBuffer that // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. - virtual std::shared_ptr allgather_base( + virtual c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -206,27 +205,27 @@ class ProcessGroup { // * do not add dependencies on this function, // * do not implement it in your ProcessGroup, implement allgather_base // instead. - virtual std::shared_ptr allgather_coalesced( + virtual c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()); - virtual std::shared_ptr gather( + virtual c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) = 0; - virtual std::shared_ptr scatter( + virtual c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) = 0; - virtual std::shared_ptr reduce_scatter( + virtual c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; - virtual std::shared_ptr alltoall_base( + virtual c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -235,28 +234,28 @@ class ProcessGroup { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr alltoall( + virtual c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr send( + virtual c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) = 0; - virtual std::shared_ptr recv( + virtual c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) = 0; - virtual std::shared_ptr recvAnysource( + virtual c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) = 0; - virtual std::shared_ptr barrier( + virtual c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) = 0; protected: diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index cd3e83e6b714..90c9b695de28 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -38,6 +38,7 @@ #endif #include +#include #include #include #include @@ -653,11 +654,11 @@ void ProcessGroupGloo::runLoop(int workerIndex) { AsyncWork::execute(std::move(work)); lock.lock(); - workInProgress_[workerIndex] = nullptr; + workInProgress_[workerIndex].reset(); } } -void ProcessGroupGloo::enqueue(std::shared_ptr work) { +void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -773,7 +774,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { } // namespace -std::shared_ptr ProcessGroupGloo::broadcast( +c10::intrusive_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -796,15 +797,15 @@ std::shared_ptr ProcessGroupGloo::broadcast( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { @@ -1300,7 +1301,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::allreduce( +c10::intrusive_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1329,15 +1330,15 @@ std::shared_ptr ProcessGroupGloo::allreduce( "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1345,10 +1346,10 @@ std::shared_ptr ProcessGroupGloo::allreduce( #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1362,7 +1363,7 @@ std::shared_ptr ProcessGroupGloo::allreduce( return work; } -std::shared_ptr ProcessGroupGloo::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1405,12 +1406,12 @@ std::shared_ptr ProcessGroupGloo::allreduce_coalesced( invalidArgument("unsupported layout"); } - std::shared_ptr work; + c10::intrusive_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); @@ -1538,7 +1539,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::reduce( +c10::intrusive_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1561,11 +1562,11 @@ std::shared_ptr ProcessGroupGloo::reduce( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1574,7 +1575,7 @@ std::shared_ptr ProcessGroupGloo::reduce( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1720,7 +1721,7 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. -std::shared_ptr ProcessGroupGloo::allgather( +c10::intrusive_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { @@ -1769,15 +1770,15 @@ std::shared_ptr ProcessGroupGloo::allgather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #endif } else { @@ -1852,7 +1853,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::allgather_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { @@ -1902,13 +1903,13 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } -std::shared_ptr ProcessGroupGloo::allgather_base( +c10::intrusive_ptr ProcessGroupGloo::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { @@ -2057,7 +2058,7 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { } // namespace -std::shared_ptr ProcessGroupGloo::gather( +c10::intrusive_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { @@ -2103,15 +2104,15 @@ std::shared_ptr ProcessGroupGloo::gather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2245,7 +2246,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { } // namespace -std::shared_ptr ProcessGroupGloo::scatter( +c10::intrusive_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { @@ -2290,15 +2291,15 @@ std::shared_ptr ProcessGroupGloo::scatter( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2308,7 +2309,7 @@ std::shared_ptr ProcessGroupGloo::scatter( return work; } -std::shared_ptr ProcessGroupGloo::reduce_scatter( +c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { @@ -2443,7 +2444,7 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { } // namespace -std::shared_ptr ProcessGroupGloo::alltoall_base( +c10::intrusive_ptr ProcessGroupGloo::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2460,12 +2461,12 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( assertDense(invalidArgument, {inputTensor}); const auto& device = outputTensor.device(); - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2474,7 +2475,7 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2510,7 +2511,7 @@ uint32_t checkTag(int32_t tag) { return (uint32_t)tag; } -std::shared_ptr ProcessGroupGloo::send( +c10::intrusive_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { @@ -2526,10 +2527,10 @@ std::shared_ptr ProcessGroupGloo::send( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the send. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recv( +c10::intrusive_ptr ProcessGroupGloo::recv( std::vector& tensors, int srcRank, int tag) { @@ -2545,10 +2546,10 @@ std::shared_ptr ProcessGroupGloo::recv( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recvAnysource( +c10::intrusive_ptr ProcessGroupGloo::recvAnysource( std::vector& tensors, int tag) { auto& tensor = checkSingleTensor(tensors); @@ -2573,7 +2574,7 @@ std::shared_ptr ProcessGroupGloo::recvAnysource( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } namespace { @@ -2582,13 +2583,13 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( const std::shared_ptr& context, - std::vector> priorWork, + std::vector> priorWork, uint32_t tag) : ProcessGroupGloo::AsyncWork("gloo:barrier"), context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; - std::vector> priorWork; + std::vector> priorWork; const uint32_t tag; void run() override { @@ -2608,9 +2609,9 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::barrier( +c10::intrusive_ptr ProcessGroupGloo::barrier( const BarrierOptions& opts) { - std::vector> priorWork; + std::vector> priorWork; // Snapshot all in progress and pending work as weak_ptr. // When executing a barrier, we need to ensure that all prior work @@ -2624,7 +2625,7 @@ std::shared_ptr ProcessGroupGloo::barrier( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), std::move(priorWork), tag); enqueue(work); return work; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 31664ad0b6cf..74fd0f6e5165 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -70,7 +70,7 @@ class ProcessGroupGloo : public ProcessGroup { public: AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} - static void execute(std::shared_ptr work) { + static void execute(c10::intrusive_ptr work) { std::exception_ptr eptr; try { work->run(); @@ -159,75 +159,75 @@ class ProcessGroupGloo : public ProcessGroup { virtual ~ProcessGroupGloo(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; protected: @@ -258,7 +258,7 @@ class ProcessGroupGloo : public ProcessGroup { void runLoop(int workerIndex); // Queue work to run on worker thread. - void enqueue(std::shared_ptr work); + void enqueue(c10::intrusive_ptr work); // Keep both a queue of pending work, and a vector with in progress work. // Both of these can only be mutated when holding the queue lock. @@ -266,8 +266,8 @@ class ProcessGroupGloo : public ProcessGroup { // to all in progress and pending work when executing a barrier. // When executing a barrier, we need to ensure that all prior work // has completed before completing itself. - std::deque> workQueue_; - std::vector> workInProgress_; + std::deque> workQueue_; + std::vector> workInProgress_; std::mutex workMutex_; std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index d3e79a1dd424..5f9d0be41b8f 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -308,9 +308,9 @@ void ProcessGroupMPI::runLoop() { } } -std::shared_ptr ProcessGroupMPI::enqueue( +c10::intrusive_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry) { - auto work = std::make_shared(); + auto work = c10::make_intrusive(); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -318,7 +318,7 @@ std::shared_ptr ProcessGroupMPI::enqueue( return work; } -std::shared_ptr ProcessGroupMPI::broadcast( +c10::intrusive_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { checkSingleTensor(tensors); @@ -339,7 +339,7 @@ std::shared_ptr ProcessGroupMPI::broadcast( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce( +c10::intrusive_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { checkSingleTensor(tensors); @@ -362,14 +362,14 @@ std::shared_ptr ProcessGroupMPI::allreduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with MPI"); } -std::shared_ptr ProcessGroupMPI::reduce( +c10::intrusive_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { checkSingleTensor(tensors); @@ -397,7 +397,7 @@ std::shared_ptr ProcessGroupMPI::reduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather( +c10::intrusive_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -441,7 +441,7 @@ std::shared_ptr ProcessGroupMPI::allgather( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -449,7 +449,7 @@ std::shared_ptr ProcessGroupMPI::allgather_coalesced( "ProcessGroupMPI does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupMPI::gather( +c10::intrusive_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { @@ -516,7 +516,7 @@ std::shared_ptr ProcessGroupMPI::gather( } } -std::shared_ptr ProcessGroupMPI::scatter( +c10::intrusive_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { @@ -582,14 +582,14 @@ std::shared_ptr ProcessGroupMPI::scatter( } } -std::shared_ptr ProcessGroupMPI::reduce_scatter( +c10::intrusive_ptr ProcessGroupMPI::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); } -std::shared_ptr ProcessGroupMPI::alltoall_base( +c10::intrusive_ptr ProcessGroupMPI::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -665,7 +665,7 @@ std::shared_ptr ProcessGroupMPI::alltoall_base( return enqueue(std::move(entry)); } } -std::shared_ptr ProcessGroupMPI::alltoall( +c10::intrusive_ptr ProcessGroupMPI::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts) { @@ -722,7 +722,7 @@ std::shared_ptr ProcessGroupMPI::alltoall( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::send( +c10::intrusive_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { @@ -744,10 +744,10 @@ std::shared_ptr ProcessGroupMPI::send( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recv( +c10::intrusive_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { @@ -769,10 +769,10 @@ std::shared_ptr ProcessGroupMPI::recv( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recvAnysource( +c10::intrusive_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { checkSingleTensor(tensors); @@ -793,10 +793,10 @@ std::shared_ptr ProcessGroupMPI::recvAnysource( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::barrier( +c10::intrusive_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { std::function&)> runFunc = [this](std::unique_ptr& entry) { @@ -808,7 +808,7 @@ std::shared_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_base( +c10::intrusive_ptr ProcessGroupMPI::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 342fe87001a0..48d95eada887 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -108,80 +108,80 @@ class ProcessGroupMPI : public ProcessGroup { // Abort the MPI program, needs to be called when exception is detected void abort(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag); - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag); - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag); - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized @@ -190,13 +190,13 @@ class ProcessGroupMPI : public ProcessGroup { protected: using WorkType = - std::tuple, std::shared_ptr>; + std::tuple, c10::intrusive_ptr>; // Worker thread loop void runLoop(); // Helper function that is called by the destructor void destroy(); - std::shared_ptr enqueue(std::unique_ptr entry); + c10::intrusive_ptr enqueue(std::unique_ptr entry); bool stop_; diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index ba0b4b36c77d..d1a4c6cb97ad 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -984,12 +984,12 @@ std::vector flatten_for_scatter_gather( } // namespace -std::shared_ptr ProcessGroupNCCL::initWork( +c10::intrusive_ptr ProcessGroupNCCL::initWork( std::vector devices, int rank, OpType opType, const char* profilingTitle) { - return std::make_shared(devices, rank, opType, profilingTitle); + return c10::make_intrusive(devices, rank, opType, profilingTitle); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1012,7 +1012,7 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: } void ProcessGroupNCCL::workEnqueue( - std::shared_ptr work) { + c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. @@ -1027,7 +1027,7 @@ ProcessGroupNCCL::Options::Options() isHighPriorityStream(false) {} template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1114,7 +1114,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, @@ -1186,7 +1186,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( } template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, @@ -1203,7 +1203,7 @@ std::shared_ptr ProcessGroupNCCL::collective( } template -std::shared_ptr ProcessGroupNCCL::pointToPoint( +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -1217,7 +1217,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( [](std::vector&) {}); } -std::shared_ptr ProcessGroupNCCL::allreduce( +c10::intrusive_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_gpu_tensors(tensors); @@ -1242,14 +1242,14 @@ std::shared_ptr ProcessGroupNCCL::allreduce( "nccl:all_reduce"); } -std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with NCCL"); } -std::shared_ptr ProcessGroupNCCL::broadcast( +c10::intrusive_ptr ProcessGroupNCCL::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_gpu_tensors(tensors); @@ -1274,7 +1274,7 @@ std::shared_ptr ProcessGroupNCCL::broadcast( "nccl:broadcast"); } -std::shared_ptr ProcessGroupNCCL::reduce( +c10::intrusive_ptr ProcessGroupNCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { check_gpu_tensors(tensors); @@ -1301,7 +1301,7 @@ std::shared_ptr ProcessGroupNCCL::reduce( "nccl:reduce"); } -std::shared_ptr ProcessGroupNCCL::allgather( +c10::intrusive_ptr ProcessGroupNCCL::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -1346,7 +1346,7 @@ std::shared_ptr ProcessGroupNCCL::allgather( "nccl:all_gather"); } -std::shared_ptr ProcessGroupNCCL::allgather_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -1354,7 +1354,7 @@ std::shared_ptr ProcessGroupNCCL::allgather_coalesced( "ProcessGroupNCCL does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupNCCL::reduce_scatter( +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { @@ -1400,7 +1400,7 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( "nccl:reduce_scatter"); } -std::shared_ptr ProcessGroupNCCL::barrier( +c10::intrusive_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { std::vector devices; if (usedDeviceIdxs_.empty()) { @@ -1441,7 +1441,7 @@ std::shared_ptr ProcessGroupNCCL::barrier( } #ifdef ENABLE_NCCL_P2P_SUPPORT -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -1512,7 +1512,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( } } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& tensors, int dstRank, int /* unused */) { @@ -1531,7 +1531,7 @@ std::shared_ptr ProcessGroupNCCL::send( return ret; } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& tensors, int srcRank, int /* unused */) { @@ -1550,7 +1550,7 @@ std::shared_ptr ProcessGroupNCCL::recv( return ret; } #else -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, at::Tensor& /* unused */, std::vector& /* unused */, @@ -1560,7 +1560,7 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::send( +c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1568,7 +1568,7 @@ std::shared_ptr ProcessGroupNCCL::send( "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } -std::shared_ptr ProcessGroupNCCL::recv( +c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { @@ -1591,34 +1591,34 @@ void ProcessGroupNCCL::groupEnd() { --ncclActiveGroupCounter_; } -std::shared_ptr ProcessGroupNCCL::alltoall( +c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); } -std::shared_ptr ProcessGroupNCCL::gather( +c10::intrusive_ptr ProcessGroupNCCL::gather( std::vector>& /* unused */, std::vector& /* unused */, const GatherOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support gather"); } -std::shared_ptr ProcessGroupNCCL::scatter( +c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector& /* unused */, std::vector>& /* unused */, const ScatterOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -std::shared_ptr ProcessGroupNCCL::recvAnysource( +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } -std::shared_ptr ProcessGroupNCCL::allgather_base( +c10::intrusive_ptr ProcessGroupNCCL::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 1520604629f2..59f06fda1ec1 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -65,7 +65,7 @@ constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; class ProcessGroupNCCL : public ProcessGroup { public: class WorkNCCL : public ProcessGroup::Work, - public std::enable_shared_from_this { + public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); @@ -411,64 +411,64 @@ class ProcessGroupNCCL : public ProcessGroup { virtual ~ProcessGroupNCCL(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; @@ -478,17 +478,17 @@ class ProcessGroupNCCL : public ProcessGroup { static void groupEnd(); // Unsupported Ops - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -515,7 +515,7 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); - virtual std::shared_ptr initWork( + virtual c10::intrusive_ptr initWork( std::vector devices, int rank, OpType opType, @@ -529,14 +529,14 @@ class ProcessGroupNCCL : public ProcessGroup { // ncclComm_t, at::cuda::CUDAStream&); // void {pre,post}(std::vector); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, OpType opType, const char* profilingTitle = nullptr); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, @@ -549,13 +549,13 @@ class ProcessGroupNCCL : public ProcessGroup { // primitives. It is the same structure as the helper used for collective // communicaiton primitives. template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, OpType opType); template - std::shared_ptr pointToPoint( + c10::intrusive_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, @@ -664,7 +664,7 @@ class ProcessGroupNCCL : public ProcessGroup { std::list workMetaList_; // Add Work Pointer to workVector - void workEnqueue(std::shared_ptr); + void workEnqueue(c10::intrusive_ptr); // The CUDA steams used by NCCL kernels std::unordered_map> diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index 032f63c320f5..c77188577a62 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -17,66 +17,66 @@ ProcessGroupRoundRobin::ProcessGroupRoundRobin( ProcessGroupRoundRobin::~ProcessGroupRoundRobin() {} -std::shared_ptr ProcessGroupRoundRobin::broadcast( +c10::intrusive_ptr ProcessGroupRoundRobin::broadcast( std::vector& tensors, const BroadcastOptions& opts) { return next()->broadcast(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( std::vector& tensors, const AllreduceOptions& opts) { return next()->allreduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::reduce( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce( std::vector& tensors, const ReduceOptions& opts) { return next()->reduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allgather( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { return next()->allgather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::allgather_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::gather( +c10::intrusive_ptr ProcessGroupRoundRobin::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { return next()->gather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { return next()->scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::reduce_scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { return next()->reduce_scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::alltoall_base( +c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -86,27 +86,27 @@ std::shared_ptr ProcessGroupRoundRobin::alltoall_base( outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); }; -std::shared_ptr ProcessGroupRoundRobin::send( +c10::intrusive_ptr ProcessGroupRoundRobin::send( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support send"); }; -std::shared_ptr ProcessGroupRoundRobin::recv( +c10::intrusive_ptr ProcessGroupRoundRobin::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::recvAnysource( +c10::intrusive_ptr ProcessGroupRoundRobin::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::barrier( +c10::intrusive_ptr ProcessGroupRoundRobin::barrier( const BarrierOptions& /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; @@ -120,7 +120,7 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } -std::shared_ptr ProcessGroupRoundRobin::allgather_base( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index bbbd0a1c756b..62d59ef18ce5 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -25,75 +25,75 @@ class ProcessGroupRoundRobin final : public ProcessGroup { ~ProcessGroupRoundRobin() override; - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: diff --git a/torch/lib/c10d/comm.cpp b/torch/lib/c10d/comm.cpp index a8628e0c942e..5ef88f058aca 100644 --- a/torch/lib/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -45,8 +45,10 @@ class BroadcastWork { // because c10d::ProcessGroup::broadcast takes a vector argument. std::vector flat_tensor_; + private: + // The broadcast work that is kicked off upon construction. - std::shared_ptr work_; + c10::intrusive_ptr work_; }; } // namespace diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 76d6a5588f7e..3de7447d092a 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -19,7 +19,7 @@ int main(int argc, char** argv) { } // Kick off work - std::vector> pending; + std::vector> pending; for (auto i = 0; i < ntensors; i++) { std::vector tmp = {tensors[i]}; pending.push_back(pg.allreduce(tmp)); diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index c05ce685bb7d..c5ee54a9ee8e 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -472,7 +472,7 @@ std::vector> Reducer::get_bucket_tensors() const { } void Reducer::set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize) { std::lock_guard lock(mutex_); forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index 4874f0dd8703..e0fe0004f88e 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -96,7 +97,7 @@ class Reducer { // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the // corresponding tensor being reduced. void set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize); // Retrieve on-device tensors used to track locally unused parameters. For @@ -158,7 +159,7 @@ class Reducer { bool local_used_maps_reduced_; // Work handle for allreduce on local_used_maps_ - std::shared_ptr local_used_work_; + c10::intrusive_ptr local_used_work_; void verify_replicas_within_process(); @@ -282,7 +283,7 @@ class Reducer { size_t pending; // Keep work handle around when this set of buckets is being reduced. - std::shared_ptr work; + c10::intrusive_ptr work; // Keep future work handle around if DDP comm hook is registered. c10::intrusive_ptr future_work; @@ -340,7 +341,7 @@ class Reducer { // A struct containing work handle and tensor for allreduce scheduled in // forward pass, if applicable. struct ForwardPassAllreduceWork { - std::shared_ptr workHandle; + c10::intrusive_ptr workHandle; at::Tensor resultTensor; // whether we should divide by the initial world_size or the no. of // remaining DDP ranks. diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 92dede9a573e..1363a842eab3 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -93,7 +93,7 @@ class AsyncInputIsOutputTest : public AsyncTest { } } - void wait(std::shared_ptr& work) { + void wait(c10::intrusive_ptr& work) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(); } @@ -130,7 +130,7 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest { AsyncAllreduceTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -156,7 +156,7 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -185,7 +185,7 @@ void runAsyncAllreduceTest( size_t numProcesses = 4, size_t numTensors = 2) { auto tests = initialize(path, numProcesses, numTensors); - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(); } @@ -229,7 +229,7 @@ void runAsyncBroadcastTest( // Try every permutation of root rank and root tensor for (size_t rootRank = 0; rootRank < numProcesses; rootRank++) { for (size_t rootTensor = 0; rootTensor < numTensors; rootTensor++) { - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(rootRank, rootTensor); } diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index da4f9b5fc106..de993a1110b4 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -44,7 +44,7 @@ class SignalTest { }); } - std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { + c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { auto store = std::make_shared<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; @@ -62,7 +62,7 @@ class SignalTest { }; // Loop until an exception happens - std::shared_ptr<::c10d::ProcessGroup::Work> work; + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work; while (true) { work = pg.allreduce(tensors); try { @@ -82,7 +82,7 @@ class SignalTest { Semaphore sem_; }; -std::shared_ptr<::c10d::ProcessGroup::Work> testSignal( +c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( const std::string& path, int signal) { Fork fork; @@ -107,7 +107,7 @@ class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { Options options) : ProcessGroupGloo(store, rank, size, options) {} - std::shared_ptr<::c10d::ProcessGroup::Work> send( + c10::intrusive_ptr<::c10d::ProcessGroup::Work> send( std::vector& tensors, int dstRank, int tag) override { @@ -200,7 +200,7 @@ void testAllreduce(const std::string& path, const at::DeviceType b) { } // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().allreduce(inputs[i]); } @@ -250,7 +250,7 @@ void testBroadcast(const std::string& path, const at::DeviceType b) { options.rootTensor = j; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().broadcast(inputs[i], options); } @@ -316,7 +316,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { }; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto rank = 0; rank < size; rank++) { work[rank] = tests[rank].getProcessGroup().alltoall_base( outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]); @@ -349,7 +349,7 @@ void testBarrier(const std::string& path) { auto tests = CollectiveTest::initialize(path, size); // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().barrier(); } diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp index 3f5a9e4cf331..6c60b3d6742d 100644 --- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp +++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp @@ -14,7 +14,7 @@ // Wait for work to complete void waitWork( std::shared_ptr pg, - std::vector> works) { + std::vector> works) { for (auto& work : works) { try { work->wait(); @@ -34,10 +34,11 @@ void testAllreduce(int iter = 1000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allreduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->allreduce(tensors); works.push_back(std::move(work)); } @@ -73,10 +74,11 @@ void testBroadcast(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->broadcast(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->broadcast(tensors); works.push_back(std::move(work)); } @@ -104,10 +106,10 @@ void testReduce(int iter = 10000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); works.push_back(std::move(work)); } @@ -150,10 +152,10 @@ void testAllgather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->allgather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -198,10 +200,10 @@ void testGather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->gather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -249,10 +251,10 @@ void testScatter(int iter = 1) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->scatter(allTensors[i], allInputTensors[i]); works.push_back(std::move(work)); } @@ -289,27 +291,27 @@ void testSendRecv(bool recvAnysource, int iter = 10000) { } if (rank == 0) { - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->send(tensors, 1, 0); works.push_back(std::move(work)); } waitWork(pg, works); } if (rank == 1) { - std::vector> works; + std::vector> works; std::vector srcRanks(allTensors.size(), -1); size_t i = 0; for (auto& tensors : allTensors) { // Kick off work if (!recvAnysource) { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recv(tensors, 0, 0); works.push_back(std::move(work)); } else { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recvAnysource(tensors, 0); works.push_back(std::move(work)); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index e906702a889d..f1348922e126 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -56,12 +56,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); } - std::shared_ptr initWork( + c10::intrusive_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return std::make_shared( + return c10::make_intrusive( devices, simulate_error_, rank, opType); } @@ -113,12 +113,12 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} - std::shared_ptr initWork( + c10::intrusive_ptr initWork( std::vector devices, int rank, c10d::OpType opType, const char* profilingTitle) override { - return std::make_shared( + return c10::make_intrusive( devices, set_timedout_error_, rank, opType); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index 92b477fae7de..efa96312aba0 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -80,7 +80,7 @@ class NCCLTest : public NCCLTestBase { } void wait( - std::shared_ptr& work, + c10::intrusive_ptr& work, std::chrono::milliseconds timeout = kNoTimeout) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(timeout); @@ -166,7 +166,7 @@ class AllreduceNCCLTest : public NCCLTest { AllreduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -189,7 +189,7 @@ class BroadcastNCCLTest : public NCCLTest { BroadcastNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -208,7 +208,7 @@ class ReduceNCCLTest : public NCCLTest { ReduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -227,7 +227,7 @@ class AllgatherNCCLTest : public NCCLTest { AllgatherNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -242,7 +242,7 @@ struct ReduceScatterNCCLTest : NCCLTest { ReduceScatterNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); From 665ac2f7b05685d1115252dcff29398c8379ec26 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 22:49:06 -0800 Subject: [PATCH 43/93] [reland] [c10d] switch Store to be managed by intrusive_ptr (#47808) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47808 reland https://github.com/pytorch/pytorch/pull/47074 Test Plan: wait for ci Reviewed By: gmagogsfm Differential Revision: D24905246 fbshipit-source-id: edeb7e6e486570ce889f12512e9dc02061d6cc03 --- test/cpp/rpc/e2e_test_base.h | 4 ++-- test/cpp_extensions/cpp_c10d_extension.cpp | 2 +- test/cpp_extensions/cpp_c10d_extension.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 22 +++++++++---------- torch/csrc/distributed/rpc/init.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 2 +- torch/csrc/distributed/rpc/tensorpipe_agent.h | 2 +- torch/lib/c10d/PrefixStore.cpp | 2 +- torch/lib/c10d/PrefixStore.hpp | 6 +++-- torch/lib/c10d/ProcessGroupGloo.cpp | 6 ++--- torch/lib/c10d/ProcessGroupGloo.hpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.cpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.hpp | 8 +++---- torch/lib/c10d/Store.hpp | 4 +++- torch/lib/c10d/frontend.hpp | 4 ++-- torch/lib/c10d/test/FileStoreTest.cpp | 9 ++++---- torch/lib/c10d/test/HashStoreTest.cpp | 6 ++--- .../c10d/test/ProcessGroupGlooAsyncTest.cpp | 2 +- torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 6 ++--- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 8 +++---- torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 2 +- torch/lib/c10d/test/TCPStoreTest.cpp | 14 ++++++------ 22 files changed, 61 insertions(+), 56 deletions(-) diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 9d3ab71c0cfc..114284839858 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -28,7 +28,7 @@ class TestE2EBase : public ::testing::Test { autogradContainer = getDistAutogradContainer(); // Setup server store. - store = std::make_shared( + store = c10::make_intrusive( serverAddress, 0, numWorkers, true, std::chrono::seconds(10)); buildRpcAgent(); @@ -147,7 +147,7 @@ class TestE2EBase : public ::testing::Test { std::shared_ptr rpcAgent; static const size_t numIters; static const size_t numWorkers; - std::shared_ptr store; + c10::intrusive_ptr store; static const char* serverAddress; }; diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index 50e5f5861caa..d5ba55a6379c 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -108,7 +108,7 @@ c10::intrusive_ptr ProcessGroupTest::recvAnysource( } std::shared_ptr ProcessGroupTest::createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index 8aeec736d440..1773953629d5 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -101,7 +101,7 @@ class ProcessGroupTest : public ProcessGroup { // Create a new ProcessGroupTest instance static std::shared_ptr createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 136efd32fc87..e9d8f618eb21 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -342,7 +342,7 @@ They are used in specifying strategies for reduction collectives, e.g., .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); auto store = - py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>( + py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, "Store", R"( @@ -546,7 +546,7 @@ Example:: >>> store.wait(["bad_key"], timedelta(seconds=10)) )"); - shared_ptr_class_<::c10d::FileStore>( + intrusive_ptr_class_<::c10d::FileStore>( module, "FileStore", store, @@ -569,7 +569,7 @@ Example:: .def(py::init()); #ifndef _WIN32 - shared_ptr_class_<::c10d::HashStore>( + intrusive_ptr_class_<::c10d::HashStore>( module, "HashStore", store, @@ -586,7 +586,7 @@ Example:: )") .def(py::init<>()); - shared_ptr_class_<::c10d::TCPStore>( + intrusive_ptr_class_<::c10d::TCPStore>( module, "TCPStore", store, @@ -626,7 +626,7 @@ Example:: std::chrono::milliseconds(::c10d::Store::kDefaultTimeout)); #endif - shared_ptr_class_<::c10d::PrefixStore>( + intrusive_ptr_class_<::c10d::PrefixStore>( module, "PrefixStore", store, @@ -639,7 +639,7 @@ that adds a prefix to each key inserted to the store. prefix (str): The prefix string that is prepended to each key before being inserted into the store. store (torch.distributed.store): A store object that forms the underlying key-value store. )") - .def(py::init>()); + .def(py::init>()); auto processGroup = shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") @@ -952,13 +952,13 @@ that adds a prefix to each key inserted to the store. processGroupGloo .def( py::init< - const std::shared_ptr<::c10d::Store>&, + const c10::intrusive_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupGloo::Options>(), py::call_guard()) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, std::chrono::milliseconds timeout) { @@ -997,13 +997,13 @@ that adds a prefix to each key inserted to the store. module, "ProcessGroupNCCL", processGroup) .def( py::init< - const std::shared_ptr<::c10d::Store>&, + const c10::intrusive_ptr<::c10d::Store>&, int, int, ::c10d::ProcessGroupNCCL::Options>(), py::call_guard()) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { @@ -1168,7 +1168,7 @@ that adds a prefix to each key inserted to the store. // Python side of the world. Calling Python functions on a Python object // completely bypasses pybind11. We need to test that the overloaded // functions call into Python and behave like we expect. - [](std::shared_ptr<::c10d::Store> store) { + [](c10::intrusive_ptr<::c10d::Store> store) { auto add = [&store](const std::string& key, int64_t value) { store->add(key, value); }; diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 1d82a619ed7e..81af4abebd5f 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -576,7 +576,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { shared_ptr_class_(module, "TensorPipeAgent", rpcAgent) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 6bf65f4c2628..eff1e7ebdf21 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -220,7 +220,7 @@ void TensorPipeAgent::collectNames() { } TensorPipeAgent::TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index b4a500de65be..b8c9a8c64e5c 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -141,7 +141,7 @@ struct AggregatedNetworkData { class TensorPipeAgent : public RpcAgent { public: TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, diff --git a/torch/lib/c10d/PrefixStore.cpp b/torch/lib/c10d/PrefixStore.cpp index 5f9a3c9c21ec..6f71e422bd0e 100644 --- a/torch/lib/c10d/PrefixStore.cpp +++ b/torch/lib/c10d/PrefixStore.cpp @@ -4,7 +4,7 @@ namespace c10d { PrefixStore::PrefixStore( const std::string& prefix, - std::shared_ptr store) + c10::intrusive_ptr store) : prefix_(prefix), store_(store) {} std::string PrefixStore::joinKey(const std::string& key) { diff --git a/torch/lib/c10d/PrefixStore.hpp b/torch/lib/c10d/PrefixStore.hpp index cad7112fbd76..ec50b3b719bf 100644 --- a/torch/lib/c10d/PrefixStore.hpp +++ b/torch/lib/c10d/PrefixStore.hpp @@ -7,7 +7,9 @@ namespace c10d { class PrefixStore : public Store { public: - explicit PrefixStore(const std::string& prefix, std::shared_ptr store); + explicit PrefixStore( + const std::string& prefix, + c10::intrusive_ptr store); virtual ~PrefixStore(){}; @@ -31,7 +33,7 @@ class PrefixStore : public Store { protected: std::string prefix_; - std::shared_ptr store_; + c10::intrusive_ptr store_; std::string joinKey(const std::string& key); std::vector joinKeys(const std::vector& keys); diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 90c9b695de28..22da878cce43 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -108,7 +108,7 @@ namespace { // Wrap c10d store as Gloo store class GlooStore : public ::gloo::rendezvous::Store { public: - GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {} + GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {} void set(const std::string& key, const std::vector& value) override { std::vector tmp(value.begin(), value.end()); @@ -131,7 +131,7 @@ class GlooStore : public ::gloo::rendezvous::Store { } protected: - std::shared_ptr<::c10d::Store> store_; + c10::intrusive_ptr<::c10d::Store> store_; }; typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); @@ -562,7 +562,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: #endif ProcessGroupGloo::ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 74fd0f6e5165..0508b6f857a1 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -152,7 +152,7 @@ class ProcessGroupGloo : public ProcessGroup { static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); explicit ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options = Options()); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index d1a4c6cb97ad..89abbc07f930 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -437,7 +437,7 @@ bool ProcessGroupNCCL::WorkNCCL::timedOut() { } ProcessGroupNCCL::ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options) diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 59f06fda1ec1..b93bd0c2d70c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -163,7 +163,7 @@ class ProcessGroupNCCL : public ProcessGroup { // Reference to the store so that we can write aborted communicators // to the store. - std::shared_ptr store_; + c10::intrusive_ptr store_; // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; @@ -393,7 +393,7 @@ class ProcessGroupNCCL : public ProcessGroup { // communicator. These NCCL communicators are cached and reused if possible. // ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options = Options()); @@ -402,7 +402,7 @@ class ProcessGroupNCCL : public ProcessGroup { // If you have existing code that uses the `groupName`, you can replace // it by specifying a `c10d::PrefixStore(groupName, store)` for store. C10_DEPRECATED ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, const std::string& groupName, @@ -594,7 +594,7 @@ class ProcessGroupNCCL : public ProcessGroup { static const int64_t kWorkCleanupThreadSleepMillis; // The store is used to broadcast the NCCL unique ID of rank 0. - std::shared_ptr store_; + c10::intrusive_ptr store_; // The number of NCCL communicators that have been created during // the lifetime of this process group. This sequence number is diff --git a/torch/lib/c10d/Store.hpp b/torch/lib/c10d/Store.hpp index e42bbf300e0b..f97e80013cdb 100644 --- a/torch/lib/c10d/Store.hpp +++ b/torch/lib/c10d/Store.hpp @@ -6,9 +6,11 @@ #include #include +#include + namespace c10d { -class Store { +class Store : public torch::CustomClassHolder { public: static constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(300); diff --git a/torch/lib/c10d/frontend.hpp b/torch/lib/c10d/frontend.hpp index 69705427b53c..3449ee30b5ef 100644 --- a/torch/lib/c10d/frontend.hpp +++ b/torch/lib/c10d/frontend.hpp @@ -35,7 +35,7 @@ class DistributedC10d { const std::chrono::milliseconds& timeout, int64_t world_size, int64_t rank, - std::shared_ptr store, + c10::intrusive_ptr store, const std::string& group_name); void destroyProcessGroup(std::shared_ptr group); @@ -202,7 +202,7 @@ class DistributedC10d { // need to use ProcessGroup or ProcesGroup* as key. std::unordered_map< std::shared_ptr, - std::pair>> + std::pair>> pg_map_; // Note, this is different mapping relationship than original Python diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index cc8da6326091..ce75c78adce7 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -41,7 +41,7 @@ std::string tmppath() { void testGetSet(std::string path, std::string prefix = "") { // Basic Set/Get on File Store { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -53,7 +53,7 @@ void testGetSet(std::string path, std::string prefix = "") { // Perform get on new instance { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::check(store, "key0", "value0"); } @@ -69,7 +69,8 @@ void stressTestStore(std::string path, std::string prefix = "") { for (auto i = 0; i < numThreads; i++) { threads.push_back(std::thread([&] { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = + c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); @@ -87,7 +88,7 @@ void stressTestStore(std::string path, std::string prefix = "") { // Check that the counter has the expected value { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(store, "counter", expected); diff --git a/torch/lib/c10d/test/HashStoreTest.cpp b/torch/lib/c10d/test/HashStoreTest.cpp index a16f83231a58..24b7fc76a417 100644 --- a/torch/lib/c10d/test/HashStoreTest.cpp +++ b/torch/lib/c10d/test/HashStoreTest.cpp @@ -11,7 +11,7 @@ void testGetSet(std::string prefix = "") { // Basic set/get { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -32,7 +32,7 @@ void testGetSet(std::string prefix = "") { // get() waits up to timeout_. { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); std::thread th([&]() { c10d::test::set(store, "key0", "value0"); }); c10d::test::check(store, "key0", "value0"); @@ -47,7 +47,7 @@ void stressTestStore(std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); for (auto i = 0; i < numThreads; i++) { diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 1363a842eab3..091ea9b2ad07 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -45,7 +45,7 @@ class AsyncTest { } void start(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index de993a1110b4..469cf32a8442 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -45,7 +45,7 @@ class SignalTest { } c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; // Set a timeout that is small enough to make this test run fast, but also @@ -101,7 +101,7 @@ c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { public: ProcessGroupGlooDelayed( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, Options options) @@ -151,7 +151,7 @@ class CollectiveTest { } void start(int rank, int size, bool delayed) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Set a timeout that is small enough to make this test run fast, but also // make sure that we don't get timeouts in the ProcessGroupGloo constructor. diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index f1348922e126..e19981c523de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -37,7 +37,7 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { public: ProcessGroupNCCLSimulateErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -106,7 +106,7 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { public: ProcessGroupNCCLTimedOutErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, c10d::ProcessGroupNCCL::Options opts) @@ -153,7 +153,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { void SetUp() override { size_t numDevices = cudaNumDevices(); TemporaryFile file; - store_ = std::make_shared<::c10d::FileStore>(file.path, 1); + store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); at::cuda::OptionalCUDAGuard deviceGuard; tensors_.resize(numDevices); @@ -168,7 +168,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { } std::vector tensors_; - std::shared_ptr<::c10d::FileStore> store_; + c10::intrusive_ptr<::c10d::FileStore> store_; }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index efa96312aba0..fa5e988273fc 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -31,7 +31,7 @@ class NCCLTestBase { } void initialize(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size)); diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 0cfa72c7801a..8073ec0345e0 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -16,7 +16,7 @@ void testHelper(const std::string& prefix = "") { const auto numThreads = 16; const auto numWorkers = numThreads + 1; - auto serverTCPStore = std::make_shared( + auto serverTCPStore = c10::make_intrusive( "127.0.0.1", 0, numWorkers, @@ -25,7 +25,7 @@ void testHelper(const std::string& prefix = "") { /* wait */ false); auto serverStore = - std::make_unique(prefix, serverTCPStore); + c10::make_intrusive(prefix, serverTCPStore); // server store auto serverThread = std::thread([&serverStore, &serverTCPStore] { // Wait for all workers to join. @@ -64,13 +64,13 @@ void testHelper(const std::string& prefix = "") { c10d::test::Semaphore sem1, sem2; // Each thread will have a client store to send/recv data - std::vector> clientTCPStores; - std::vector> clientStores; + std::vector> clientTCPStores; + std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - clientTCPStores.push_back(std::make_unique( + clientTCPStores.push_back(c10::make_intrusive( "127.0.0.1", serverTCPStore->getPort(), numWorkers, false)); - clientStores.push_back(std::unique_ptr( - new c10d::PrefixStore(prefix, clientTCPStores[i]))); + clientStores.push_back( + c10::make_intrusive(prefix, clientTCPStores[i])); } std::string expectedCounterRes = std::to_string(numThreads * numIterations + 1); From a02baa0c7a30caac6a7fe320672061eb7e86e389 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Nov 2020 22:49:06 -0800 Subject: [PATCH 44/93] [reland][c10d] switch ProcessGroupNCCL:Options to be managed by intrusive_ptr (#47807) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47807 reland https://github.com/pytorch/pytorch/pull/47075 Test Plan: wait for ci Reviewed By: gmagogsfm Differential Revision: D24905247 fbshipit-source-id: abd9731d86b3bd48d60bbc90d534823e0c037b93 --- torch/csrc/distributed/c10d/init.cpp | 12 +++++++----- torch/lib/c10d/ProcessGroupNCCL.cpp | 6 +++--- torch/lib/c10d/ProcessGroupNCCL.hpp | 16 +++++++++++++--- .../lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp | 16 ++++++++-------- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index e9d8f618eb21..dd32ff91603c 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1000,16 +1000,17 @@ that adds a prefix to each key inserted to the store. const c10::intrusive_ptr<::c10d::Store>&, int, int, - ::c10d::ProcessGroupNCCL::Options>(), + const c10::intrusive_ptr< + ::c10d::ProcessGroupNCCL::Options>&>(), py::call_guard()) .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { - ::c10d::ProcessGroupNCCL::Options options; - options.isHighPriorityStream = false; - options.opTimeout = timeout; + auto options = ::c10d::ProcessGroupNCCL::Options::create(); + options->isHighPriorityStream = false; + options->opTimeout = timeout; return std::make_shared<::c10d::ProcessGroupNCCL>( store, rank, size, options); }), @@ -1020,7 +1021,8 @@ that adds a prefix to each key inserted to the store. ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), py::call_guard()); - py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") + intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( + processGroupNCCL, "Options") .def(py::init<>()) .def_readwrite( "is_high_priority", diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 89abbc07f930..59219c07b32f 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -440,14 +440,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( const c10::intrusive_ptr& store, int rank, int size, - Options options) + const c10::intrusive_ptr& options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), terminateProcessGroup_(false), - opTimeout_(options.opTimeout), + opTimeout_(options->opTimeout), futureNCCLCallbackStreams_(c10::cuda::device_count()), - isHighPriorityStream_(options.isHighPriorityStream) { + isHighPriorityStream_(options->isHighPriorityStream) { TORCH_CHECK(at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b93bd0c2d70c..b84cc4deb051 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -17,6 +18,8 @@ #include #include +#include + namespace c10d { // Environment variable which controls whether or not wait() is blocking or @@ -175,9 +178,16 @@ class ProcessGroupNCCL : public ProcessGroup { friend class ProcessGroupNCCL; }; - struct Options { + struct Options : torch::CustomClassHolder { explicit Options(); + // return intrusive_ptr of the object + static c10::intrusive_ptr create( + std::chrono::milliseconds timeout = kNoTimeout, + bool isHighStream = false) { + return c10::make_intrusive(); + } + std::chrono::milliseconds opTimeout; bool isHighPriorityStream; }; @@ -396,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup { const c10::intrusive_ptr& store, int rank, int size, - Options options = Options()); + const c10::intrusive_ptr& options = Options::create()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace @@ -406,7 +416,7 @@ class ProcessGroupNCCL : public ProcessGroup { int rank, int size, const std::string& groupName, - Options options = Options()) + const c10::intrusive_ptr& options = Options::create()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index e19981c523de..82ca25049c63 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + const c10::intrusive_ptr& opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + const c10::intrusive_ptr& opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} @@ -177,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(1000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); @@ -206,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg( store_, 0, 1, options); @@ -229,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { return; } - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); From 4b25d83e9bdaee701967d7aff625cedf4c12913c Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 11 Nov 2020 22:52:57 -0800 Subject: [PATCH 45/93] torch.dropout: fix non-contiguous layout input (#47552) Summary: Fixes https://github.com/pytorch/pytorch/issues/47176 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47552 Reviewed By: ailzhang Differential Revision: D24903435 Pulled By: ngimel fbshipit-source-id: ef5398931dddf452f5f734b4aa40c11f4ee61664 --- aten/src/ATen/native/cuda/Dropout.cu | 205 ++++++++++++++++----------- test/test_nn.py | 51 +++++++ 2 files changed, 176 insertions(+), 80 deletions(-) diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index c417a7ccfabd..79736677debc 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -102,7 +102,8 @@ template < typename scalar_t, typename accscalar_t, typename IndexType, - int ADims> + int ADims, + int BDims=ADims> #if __CUDA_ARCH__ >= 350 C10_LAUNCH_BOUNDS_2(256, 8) #elif defined (__HIP_PLATFORM_HCC__) @@ -149,7 +150,7 @@ fused_dropout_kernel(cuda::detail::TensorInfo a, if (li < totalElements) { // Convert `linearIndex` into an offset of `b` const IndexType bOffset = - cuda::detail::IndexToOffset::get(li, b); + cuda::detail::IndexToOffset::get(li, b); b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv; c.data[bOffset] = (uint8_t)(&rand.x)[ii]; } @@ -178,8 +179,7 @@ template int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { int vec_size = 4; // get the vector size - auto memory_format = self.suggest_memory_format(); - if (!self.is_contiguous(memory_format) || !ret.is_contiguous(memory_format) || !mask.is_contiguous(memory_format)) { + if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() || !mask.is_non_overlapping_and_dense()) { vec_size = 1; } else { vec_size = memory::can_vectorize_up_to((char*)self.data_ptr()); @@ -194,13 +194,128 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { return can_vectorize ? vec_size : 1; } +template +inline void launcher( + const Tensor& self, + Tensor& ret, + Tensor& mask, + double p, + const int64_t nelem, + const std::pair rng_engine_inputs, + dim3 grid, + dim3 dim_block) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "fused_dropout", + [&] { + AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(p); + auto self_info = + cuda::detail::getTensorInfo(self); + auto ret_info = + cuda::detail::getTensorInfo(ret); + auto mask_info = + cuda::detail::getTensorInfo(mask); + self_info.collapseDims(); + ret_info.collapseDims(); + mask_info.collapseDims(); // ret and mask are collapsed to 1d + // contiguous tensor + + int vec_size = get_vector_size(self, ret, mask); + + if (vec_size > 1) { + switch (vec_size) { + case 4: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 4> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 2> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + } else { + switch (self_info.dims) { + case 1: + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + if (!self.is_contiguous() && ret.is_contiguous() && + mask.is_contiguous()) { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + } + }); + }); +} + } //anonymous namespace std::tuple fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - Tensor ret = at::empty_like(self, self.suggest_memory_format()); - Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format()); + Tensor ret = at::empty_like(self); + Tensor mask = at::empty_like(self, self.options().dtype(kByte)); const int64_t nelem = self.numel(); //empty tensors should not get here, but just in case, avoid FPE if (nelem==0) return std::tuple(self, mask); @@ -218,81 +333,11 @@ fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ rng_engine_inputs = gen->philox_engine_inputs(counter_offset); } if (cuda::detail::canUse32BitIndexMath(self)){ - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "fused_dropout", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] { - using accscalar_t = acc_type; - accscalar_t pa = (accscalar_t)(p); - auto self_info = cuda::detail::getTensorInfo(self); - auto ret_info = cuda::detail::getTensorInfo(ret); - auto mask_info = cuda::detail::getTensorInfo(mask); - self_info.collapseDims(); - ret_info.collapseDims(); - mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor - - int vec_size = get_vector_size(self, ret, mask); - - if (vec_size > 1) { - switch (vec_size) { - case 4: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - case 2: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - } - } else { - switch (self_info.dims) { - case 1: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - default: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - } - } - }); - }); + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "fused_dropout", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] { - using accscalar_t = acc_type; - accscalar_t pa = (accscalar_t)(p); - auto self_info = cuda::detail::getTensorInfo(self); - auto ret_info = cuda::detail::getTensorInfo(ret); - auto mask_info = cuda::detail::getTensorInfo(mask); - self_info.collapseDims(); - ret_info.collapseDims(); - mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor - - int vec_size = get_vector_size(self, ret, mask); - - if (vec_size > 1) { - switch (vec_size) { - case 4: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - case 2: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - } - } else { - switch (self_info.dims) { - case 1: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - break; - default: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); - } - } - }); - }); + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); } return std::tuple(ret, mask); } diff --git a/test/test_nn.py b/test/test_nn.py index 6b8c97db2f52..020b206905d9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9721,6 +9721,46 @@ def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_forma module.__repr__() str(module) + def _test_dropout_discontiguous(self, cls, device, memory_format=torch.contiguous_format): + # In this test, we verify that dropout preserves the layout and data for different memory formats. + # We check whether, we get same values for the output of dropout, when the probability + # of dropout is 0 or very close to 0. + # Reference: https://github.com/pytorch/pytorch/issues/47176 + close_to_zero_p = 1e-10 # Should be almost zero but not zero, as for p=0 different path is taken + for p in [0, close_to_zero_p]: + inp = torch.ones(2, 3, 3, 3, device=device) + inp_discontiguous = torch.empty(2, 3, 3, 6, device=device, memory_format=memory_format)[..., ::2] + inp_discontiguous.copy_(inp) + mod = cls(p=p) + out = mod(inp_discontiguous) + if p != 0: # Zero will keep strides as is based on input. + # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2) + # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1) + self.assertTrue(out.is_contiguous(memory_format=memory_format)) + self.assertEqual(inp_discontiguous, out) + + def _test_dropout_stride_mean_preserve(self, cls, device): + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2], d[3]) + + inp = torch.ones(2, 3, 4, 5, device=device) + shifts = [(0, 0), (1, 0), (0, 1), (1, 1)] + for perm in itertools.permutations((0, 1, 2, 3), r=4): + for shift in shifts: + for p in [1e-10, 0.3, 0.5, 0.7]: + mod = cls(p=p) + permuted_inp = inp.permute(perm).contiguous().permute(invert_perm(perm)) + permuted_inp = permuted_inp[shift[0]:, shift[1]:, :, :] + out = mod(permuted_inp) + + self.assertTrue(out.permute(perm).is_contiguous()) + self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5) + if p == 1e-10: + self.assertEqual(permuted_inp, out) + else: + self.assertNotEqual(permuted_inp, out) + def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float): # default case track_running_stats=False b, c = input.size(0), input.size(1) @@ -10159,6 +10199,11 @@ def test_Dropout(self, device): input = torch.Tensor(1000) self._test_dropout(nn.Dropout, device, input) + self._test_dropout_discontiguous(nn.Dropout, device) + self._test_dropout_discontiguous(nn.Dropout, device, memory_format=torch.channels_last) + + self._test_dropout_stride_mean_preserve(nn.Dropout, device) + if self.device_type == 'cuda' and TEST_WITH_ROCM: input = input.bfloat16() self._test_dropout(nn.Dropout, device, input) @@ -10172,6 +10217,9 @@ def test_Dropout2d(self, device): self._test_dropout(nn.Dropout2d, device, input) self._test_dropout(nn.Dropout2d, device, input, memory_format=torch.channels_last) + self._test_dropout_discontiguous(nn.Dropout2d, device) + self._test_dropout_discontiguous(nn.Dropout2d, device, memory_format=torch.channels_last) + def test_Dropout3d(self, device): b = random.randint(1, 5) w = random.randint(1, 5) @@ -10181,6 +10229,9 @@ def test_Dropout3d(self, device): input = torch.Tensor(num_features, b, d, w, h) self._test_dropout(nn.Dropout3d, device, input) + self._test_dropout_discontiguous(nn.Dropout3d, device) + self._test_dropout_discontiguous(nn.Dropout3d, device, memory_format=torch.channels_last) + def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) c = random.randint(3, 5) From 2907447c97c0e7a8673951907a4b94a682c39526 Mon Sep 17 00:00:00 2001 From: ArtistBanda Date: Thu, 12 Nov 2020 00:13:01 -0800 Subject: [PATCH 46/93] Spurious numpy writable warning (#47271) Summary: Fixes https://github.com/pytorch/pytorch/issues/47160 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47271 Reviewed By: ailzhang Differential Revision: D24855889 Pulled By: mruberry fbshipit-source-id: beaf232b115872f20fb0292e995a876cdc429868 --- test/test_tensor_creation_ops.py | 9 +++++++++ torch/csrc/utils/tensor_new.cpp | 2 +- torch/csrc/utils/tensor_numpy.cpp | 8 +++++--- torch/csrc/utils/tensor_numpy.h | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 74f0f4c34017..bd1f3bc909eb 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -1165,6 +1165,15 @@ def test_full_out(self, device): self.assertEqual(torch.full(o.shape, 1., out=o).dtype, o.dtype) self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype) + # check that warning for numpy being not writable is suppressed + # when a copy of it is being created. + # see issue #47160 + def test_tensor_from_non_writable_numpy(self, device): + with warnings.catch_warnings(record=True) as w: + a = np.arange(5.) + a.flags.writeable = False + t = torch.tensor(a) + self.assertEqual(len(w), 0) # Class for testing random tensor creation ops, like torch.randint diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index e637e45eaade..87472f4cd81f 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -270,7 +270,7 @@ Tensor internal_new_from_data( if (PyArray_Check(data)) { TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy"); - auto tensor = tensor_from_numpy(data); + auto tensor = tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key)); pybind11::gil_scoped_release no_gil; diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 8c17c2ac7492..c2a67f8df06b 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -7,7 +7,7 @@ namespace torch { namespace utils { PyObject* tensor_to_numpy(const at::Tensor& tensor) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -at::Tensor tensor_from_numpy(PyObject* obj) { +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } bool is_numpy_int(PyObject* obj) { @@ -125,13 +125,15 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) { return array.release(); } -at::Tensor tensor_from_numpy(PyObject* obj) { +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { if (!PyArray_Check(obj)) { throw TypeError("expected np.ndarray (got %s)", Py_TYPE(obj)->tp_name); } auto array = (PyArrayObject*)obj; - if (!PyArray_ISWRITEABLE(array)) { + // warn_if_not_writable is true when a copy of numpy variable is created. + // the warning is suppressed when a copy is being created. + if (!PyArray_ISWRITEABLE(array) && warn_if_not_writeable) { TORCH_WARN_ONCE( "The given NumPy array is not writeable, and PyTorch does " "not support non-writeable tensors. This means you can write to the " diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index f984d6b93a90..c4c93637db54 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -6,7 +6,7 @@ namespace torch { namespace utils { PyObject* tensor_to_numpy(const at::Tensor& tensor); -at::Tensor tensor_from_numpy(PyObject* obj); +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true); int aten_to_numpy_dtype(const at::ScalarType scalar_type); at::ScalarType numpy_dtype_to_aten(int dtype); From 2df5600155da765ac260f4cf13b1efd9fb48fbb1 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 12 Nov 2020 07:09:19 -0800 Subject: [PATCH 47/93] [ROCm] add skipCUDAIfRocm to test_lingalg test_norm_fro_2_equivalence_old (#47809) Summary: This test started failing when ROCm CI moved to 3.9. Skip until triage is complete. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47809 Reviewed By: seemethere Differential Revision: D24906319 Pulled By: walterddr fbshipit-source-id: 0c425f3b21190cfbc5e0d1c3f477d834af40f0ca --- test/test_linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index a64ea5302447..a6a7fa9a8088 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, - skipCUDAIfNoMagmaAndNoCusolver, onlyOnCPUAndCUDA) + skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck, gradgradcheck @@ -886,6 +886,7 @@ def gen_error_message(input_size, p, keepdim, dim=None): # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations @dtypes(torch.float) + @skipCUDAIfRocm def test_norm_fro_2_equivalence_old(self, device, dtype): input_sizes = [ (0,), From 859e054314ea99c00f0aac3cbaedd9de3d192e87 Mon Sep 17 00:00:00 2001 From: Kyle Chen Date: Thu, 12 Nov 2020 07:09:22 -0800 Subject: [PATCH 48/93] skip test_all_reduce_sum_cuda_async test case for ROCM (#47630) Summary: Skip the following test case for rocm (When PYTORCH_TEST_WITH_ROCM=1): - test_all_reduce_sum_cuda_async (__main__.TestDistBackendWithFork) jeffdaily pruthvistony Pull Request resolved: https://github.com/pytorch/pytorch/pull/47630 Reviewed By: seemethere, heitorschueroff Differential Revision: D24849755 Pulled By: walterddr fbshipit-source-id: b952c81677df2dfd35d459b94ce0f7a5b12c0d5c --- torch/testing/_internal/distributed/distributed_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index fd14039e1859..641ccd739ec7 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1378,6 +1378,7 @@ def test_all_reduce_sum_cuda(self): "Only Gloo and NCCL backends will have CUDA allReduce tested", ) @skip_if_no_gpu + @skip_if_rocm def test_all_reduce_sum_cuda_async(self): group, group_id, rank = self._init_global_test() rank_to_GPU = self._init_multigpu_helper() From 553ccccc54c67e1640e6f1df826f183d71265b95 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 12 Nov 2020 07:34:13 -0800 Subject: [PATCH 49/93] [c10d] switch ProcessGroup to be managed by intrusive_ptr (#47343) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47343 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D24723418 Pulled By: wanchaol fbshipit-source-id: 0463819b96c53b12bdbb3905431110d7b21beb77 --- test/cpp/rpc/test_e2e_process_group.cpp | 4 +- test/cpp/rpc/test_e2e_tensorpipe.cpp | 4 +- test/cpp_extensions/cpp_c10d_extension.cpp | 4 +- test/cpp_extensions/cpp_c10d_extension.hpp | 8 +- torch/csrc/distributed/c10d/init.cpp | 25 ++-- torch/csrc/distributed/rpc/init.cpp | 4 +- .../distributed/rpc/process_group_agent.cpp | 2 +- .../distributed/rpc/process_group_agent.h | 4 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 2 +- torch/csrc/distributed/rpc/tensorpipe_agent.h | 4 +- .../testing/faulty_process_group_agent.cpp | 2 +- .../rpc/testing/faulty_process_group_agent.h | 2 +- torch/csrc/distributed/rpc/testing/init.cpp | 2 +- torch/lib/c10d/ProcessGroup.hpp | 2 +- torch/lib/c10d/ProcessGroupMPI.cpp | 6 +- torch/lib/c10d/ProcessGroupMPI.hpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.cpp | 2 +- torch/lib/c10d/ProcessGroupNCCL.hpp | 4 +- torch/lib/c10d/ProcessGroupRoundRobin.cpp | 4 +- torch/lib/c10d/ProcessGroupRoundRobin.hpp | 8 +- torch/lib/c10d/comm.cpp | 5 +- torch/lib/c10d/comm.hpp | 2 +- torch/lib/c10d/frontend.cpp | 111 +++++++++--------- torch/lib/c10d/frontend.hpp | 108 ++++++++--------- torch/lib/c10d/reducer.cpp | 2 +- torch/lib/c10d/reducer.hpp | 4 +- torch/lib/c10d/test/ProcessGroupMPITest.cpp | 2 +- .../c10d/test/ProcessGroupNCCLErrorsTest.cpp | 4 +- 28 files changed, 166 insertions(+), 167 deletions(-) diff --git a/test/cpp/rpc/test_e2e_process_group.cpp b/test/cpp/rpc/test_e2e_process_group.cpp index 7c5af57d6a09..01bed87687a4 100644 --- a/test/cpp/rpc/test_e2e_process_group.cpp +++ b/test/cpp/rpc/test_e2e_process_group.cpp @@ -22,8 +22,8 @@ class TestE2EProcessGroup : public TestE2EBase { options.timeout = rpcTimeout; // Initialize server rpc agent. - auto pg = - std::make_shared(store, 0, numWorkers, options); + auto pg = c10::make_intrusive( + store, 0, numWorkers, options); rpcAgent = std::make_shared( "worker", diff --git a/test/cpp/rpc/test_e2e_tensorpipe.cpp b/test/cpp/rpc/test_e2e_tensorpipe.cpp index 8fecf6dffb75..d0d00f4cce5a 100644 --- a/test/cpp/rpc/test_e2e_tensorpipe.cpp +++ b/test/cpp/rpc/test_e2e_tensorpipe.cpp @@ -23,8 +23,8 @@ class TestE2ETensorPipe : public TestE2EBase { float rpcTimeout = 30; // Initialize server rpc agent. - auto pg = - std::make_shared(store, 0, numWorkers, options); + auto pg = c10::make_intrusive( + store, 0, numWorkers, options); TensorPipeRpcBackendOptions opts( /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()), diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index d5ba55a6379c..d01d07f208d7 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -107,12 +107,12 @@ c10::intrusive_ptr ProcessGroupTest::recvAnysource( throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); } -std::shared_ptr ProcessGroupTest::createProcessGroupTest( +c10::intrusive_ptr ProcessGroupTest::createProcessGroupTest( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { - return std::make_shared(rank, size); + return c10::make_intrusive(rank, size); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index 1773953629d5..6b5070e306e9 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -88,19 +88,19 @@ class ProcessGroupTest : public ProcessGroup { c10::intrusive_ptr send( std::vector& tensors, int dstRank, - int tag); + int tag) override; c10::intrusive_ptr recv( std::vector& tensors, int srcRank, - int tag); + int tag) override; c10::intrusive_ptr recvAnysource( std::vector& tensor, - int tag); + int tag) override; // Create a new ProcessGroupTest instance - static std::shared_ptr createProcessGroupTest( + static c10::intrusive_ptr createProcessGroupTest( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index dd32ff91603c..f17802c6f0b1 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -223,7 +223,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO py::init< std::vector>, std::vector>, - std::shared_ptr<::c10d::ProcessGroup>, + c10::intrusive_ptr<::c10d::ProcessGroup>, std::vector>, int64_t, bool, @@ -642,7 +642,7 @@ that adds a prefix to each key inserted to the store. .def(py::init>()); auto processGroup = - shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") + intrusive_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -907,13 +907,13 @@ that adds a prefix to each key inserted to the store. #ifndef _WIN32 module.def( "_round_robin_process_groups", - [](std::vector> processGroups) - -> std::shared_ptr<::c10d::ProcessGroup> { + [](std::vector> processGroups) + -> c10::intrusive_ptr<::c10d::ProcessGroup> { if (processGroups.size() == 0) { throw std::invalid_argument("Specify at least 1 process group"); } const auto& first = processGroups.front(); - return std::make_shared<::c10d::ProcessGroupRoundRobin>( + return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>( first->getRank(), first->getSize(), std::move(processGroups)); }, py::arg("process_groups"), @@ -921,7 +921,7 @@ that adds a prefix to each key inserted to the store. #endif #ifdef USE_C10D_GLOO - auto processGroupGloo = shared_ptr_class_<::c10d::ProcessGroupGloo>( + auto processGroupGloo = intrusive_ptr_class_<::c10d::ProcessGroupGloo>( module, "ProcessGroupGloo", processGroup); shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); @@ -981,7 +981,7 @@ that adds a prefix to each key inserted to the store. options.timeout = timeout; options.threads = options.devices.size() * 2; - return std::make_shared<::c10d::ProcessGroupGloo>( + return c10::make_intrusive<::c10d::ProcessGroupGloo>( store, rank, size, options); }), py::arg("store"), @@ -993,15 +993,14 @@ that adds a prefix to each key inserted to the store. #ifdef USE_C10D_NCCL auto processGroupNCCL = - shared_ptr_class_<::c10d::ProcessGroupNCCL>( + intrusive_ptr_class_<::c10d::ProcessGroupNCCL>( module, "ProcessGroupNCCL", processGroup) .def( py::init< const c10::intrusive_ptr<::c10d::Store>&, int, int, - const c10::intrusive_ptr< - ::c10d::ProcessGroupNCCL::Options>&>(), + c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(), py::call_guard()) .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, @@ -1011,7 +1010,7 @@ that adds a prefix to each key inserted to the store. auto options = ::c10d::ProcessGroupNCCL::Options::create(); options->isHighPriorityStream = false; options->opTimeout = timeout; - return std::make_shared<::c10d::ProcessGroupNCCL>( + return c10::make_intrusive<::c10d::ProcessGroupNCCL>( store, rank, size, options); }), py::arg("store"), @@ -1036,7 +1035,7 @@ that adds a prefix to each key inserted to the store. #endif #ifdef USE_C10D_MPI - auto processGroupMPI = shared_ptr_class_<::c10d::ProcessGroupMPI>( + auto processGroupMPI = intrusive_ptr_class_<::c10d::ProcessGroupMPI>( module, "ProcessGroupMPI", processGroup); // Define static create function instead of a constructor, because @@ -1149,7 +1148,7 @@ that adds a prefix to each key inserted to the store. // Define a lambda such that the pybind11 prototype can take a std::vector // for the tensor list argument, but still pass it to the underlying // function as a c10::ArrayRef. - [](std::shared_ptr<::c10d::ProcessGroup> process_group, + [](c10::intrusive_ptr<::c10d::ProcessGroup> process_group, std::vector tensors, // NOLINT size_t buffer_size, int rank) { diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 81af4abebd5f..9b28ecbdd4bb 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -494,7 +494,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { shared_ptr_class_(module, "ProcessGroupAgent", rpcAgent) .def(py::init([](std::string workerName, - const std::shared_ptr<::c10d::ProcessGroup>& pg, + const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout) { return std::make_unique( @@ -580,7 +580,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr<::c10d::ProcessGroup> processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts) { return std::make_shared( store, diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 13e685b8fe74..b106f1442d31 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -90,7 +90,7 @@ void ProcessGroupAgent::collectNames() { ProcessGroupAgent::ProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, std::unique_ptr cb) diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 70fb1b40244d..61d17f03e623 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -61,7 +61,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { public: ProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, std::unique_ptr cb); @@ -209,7 +209,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { return ++nextId_; } - std::shared_ptr pg_; + c10::intrusive_ptr<::c10d::ProcessGroup> pg_; // worker name -> rank std::unordered_map nameMap_; std::vector allWorkerInfo_; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index eff1e7ebdf21..16510d3315cb 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -224,7 +224,7 @@ TensorPipeAgent::TensorPipeAgent( std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts, std::unique_ptr cb) : RpcAgent( diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index b8c9a8c64e5c..1eb5cab5be82 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -145,7 +145,7 @@ class TensorPipeAgent : public RpcAgent { std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts, std::unique_ptr cb); @@ -283,7 +283,7 @@ class TensorPipeAgent : public RpcAgent { // The join method is required to behave like a barrier and perform collective // operations. For simplicity and reliability, we offload this to a process // group, but probably one day we might want to re-implement them using RPCs. - const std::shared_ptr processGroup_; + const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_; mutable std::mutex mutex_; uint64_t nextMessageID_{0}; diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index a1be688a285e..dccf1abb6d3e 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -12,7 +12,7 @@ std::string fromVec(const std::vector& vec) { FaultyProcessGroupAgent::FaultyProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, const std::vector& messagesToFail, diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index f240f6847c44..25a162bbd559 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -35,7 +35,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { public: FaultyProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, const std::vector& messagesToFail, diff --git a/torch/csrc/distributed/rpc/testing/init.cpp b/torch/csrc/distributed/rpc/testing/init.cpp index a662faed88ba..da9b3477d51a 100644 --- a/torch/csrc/distributed/rpc/testing/init.cpp +++ b/torch/csrc/distributed/rpc/testing/init.cpp @@ -66,7 +66,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) { .def( py::init< std::string, - std::shared_ptr<::c10d::ProcessGroup>, + c10::intrusive_ptr<::c10d::ProcessGroup>, int, std::chrono::milliseconds, const std::vector&, diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 63996b516a06..ea4b1428038a 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -68,7 +68,7 @@ bool isP2POp(OpType opType); // process group to find each other (referred to as rendezvous from // hereon) // -class ProcessGroup { +class ProcessGroup : public torch::CustomClassHolder { public: // Please do not use ProcessGroup::Work API, it is going away, to be // replaced by ivalue::Future. diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index 5f9d0be41b8f..250b635e8c69 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -199,7 +199,7 @@ void ProcessGroupMPI::initMPIOnce() { }); } -std::shared_ptr ProcessGroupMPI::createProcessGroupMPI( +c10::intrusive_ptr ProcessGroupMPI::createProcessGroupMPI( std::vector ranks) { // Once initialization initMPIOnce(); @@ -238,10 +238,10 @@ std::shared_ptr ProcessGroupMPI::createProcessGroupMPI( // process group instance. This is in line with the semantics of the // other process group types. if (groupComm == MPI_COMM_NULL) { - return std::shared_ptr(); + return c10::intrusive_ptr(); } - return std::make_shared(rank, size, groupComm); + return c10::make_intrusive(rank, size, groupComm); } ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm) diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 48d95eada887..16f25e5c4895 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -185,7 +185,7 @@ class ProcessGroupMPI : public ProcessGroup { const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized - static std::shared_ptr createProcessGroupMPI( + static c10::intrusive_ptr createProcessGroupMPI( std::vector ranks = {}); protected: diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 59219c07b32f..c3a245fb13da 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -440,7 +440,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options) + c10::intrusive_ptr options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b84cc4deb051..fd57f105df0b 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -406,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options = Options::create()); + c10::intrusive_ptr options = Options::create()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace @@ -416,7 +416,7 @@ class ProcessGroupNCCL : public ProcessGroup { int rank, int size, const std::string& groupName, - const c10::intrusive_ptr& options = Options::create()) + c10::intrusive_ptr options = Options::create()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index c77188577a62..455c1654f587 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -5,7 +5,7 @@ namespace c10d { ProcessGroupRoundRobin::ProcessGroupRoundRobin( int rank, int size, - std::vector> processGroups) + std::vector> processGroups) : ProcessGroup(rank, size), processGroups_(std::move(processGroups)) { TORCH_CHECK(processGroups_.size() >= 1); for (const auto& processGroup : processGroups_) { @@ -111,7 +111,7 @@ c10::intrusive_ptr ProcessGroupRoundRobin::barrier( throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; -const std::shared_ptr& ProcessGroupRoundRobin::next() { +const c10::intrusive_ptr& ProcessGroupRoundRobin::next() { auto& processGroup = *iterator_; iterator_++; if (iterator_ == processGroups_.end()) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index 62d59ef18ce5..a8c2eba115a6 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -21,7 +21,7 @@ class ProcessGroupRoundRobin final : public ProcessGroup { explicit ProcessGroupRoundRobin( int rank, int size, - std::vector> processGroups); + std::vector> processGroups); ~ProcessGroupRoundRobin() override; @@ -97,11 +97,11 @@ class ProcessGroupRoundRobin final : public ProcessGroup { const BarrierOptions& opts = BarrierOptions()) override; private: - std::vector> processGroups_; - std::vector>::const_iterator iterator_; + std::vector> processGroups_; + std::vector>::const_iterator iterator_; // Returns the next ProcessGroup to use. - const std::shared_ptr& next(); + const c10::intrusive_ptr& next(); }; } // namespace c10d diff --git a/torch/lib/c10d/comm.cpp b/torch/lib/c10d/comm.cpp index 5ef88f058aca..1db8901b2859 100644 --- a/torch/lib/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -13,7 +13,7 @@ namespace { class BroadcastWork { public: BroadcastWork( - const std::shared_ptr& process_group, + const c10::intrusive_ptr& process_group, std::vector bucket_tensors, int root_rank = 0) : bucket_tensors_(std::move(bucket_tensors)), @@ -55,7 +55,7 @@ class BroadcastWork { // Broadcast many tensors to all processes in the process group. void broadcast_coalesced( - std::shared_ptr process_group, + c10::intrusive_ptr process_group, at::TensorList tensors, size_t buffer_size, int rank) { @@ -87,5 +87,4 @@ void broadcast_coalesced( } } - } // namespace c10d diff --git a/torch/lib/c10d/comm.hpp b/torch/lib/c10d/comm.hpp index 0dd5815a5b8b..e1bde1f03ec0 100644 --- a/torch/lib/c10d/comm.hpp +++ b/torch/lib/c10d/comm.hpp @@ -8,7 +8,7 @@ namespace c10d { // Broadcast many tensors to all processes in the process group. void broadcast_coalesced( - std::shared_ptr process_group, + c10::intrusive_ptr process_group, at::TensorList tensors, size_t buffer_size, int rank = 0); diff --git a/torch/lib/c10d/frontend.cpp b/torch/lib/c10d/frontend.cpp index 2a5e3d92f407..0667e2c98a00 100644 --- a/torch/lib/c10d/frontend.cpp +++ b/torch/lib/c10d/frontend.cpp @@ -44,7 +44,8 @@ bool assertReduceOpSupportsComplexTensor(ReduceOp op) { // we need many additional conditionals to check whether group is WORLD and // then use default_pg_ explicitly. -int64_t DistributedC10d::getRank(const std::shared_ptr& group) const { +int64_t DistributedC10d::getRank( + const c10::intrusive_ptr& group) const { if (rankNotInGroup(group)) { return -1; } @@ -53,7 +54,7 @@ int64_t DistributedC10d::getRank(const std::shared_ptr& group) con } int64_t DistributedC10d::getWorldSize( - const std::shared_ptr& group) const { + const c10::intrusive_ptr& group) const { if (rankNotInGroup(group)) { return -1; } @@ -62,7 +63,7 @@ int64_t DistributedC10d::getWorldSize( } int64_t DistributedC10d::getGroupSize( - const std::shared_ptr& group) const { + const c10::intrusive_ptr& group) const { if (group == default_pg_) { default_pg_->getSize(); } @@ -73,13 +74,13 @@ int64_t DistributedC10d::getGroupSize( return it->second.size(); } -std::shared_ptr DistributedC10d::worldProcessGroup() { +c10::intrusive_ptr DistributedC10d::worldProcessGroup() { checkDefaultPg(); return default_pg_; } bool DistributedC10d::rankNotInGroup( - const std::shared_ptr& group) const { + const c10::intrusive_ptr& group) const { if (group == default_pg_) { return false; } @@ -87,7 +88,7 @@ bool DistributedC10d::rankNotInGroup( } int64_t DistributedC10d::getGroupRank( - const std::shared_ptr& group, + const c10::intrusive_ptr& group, const int64_t rank) const { TORCH_CHECK( group != default_pg_, @@ -117,7 +118,7 @@ int64_t DistributedC10d::getGroupRank( } int64_t DistributedC10d::getGlobalRank( - const std::shared_ptr& group, + const c10::intrusive_ptr& group, const int64_t group_rank) const { TORCH_CHECK( group != default_pg_, @@ -137,7 +138,7 @@ int64_t DistributedC10d::getGlobalRank( } std::string DistributedC10d::getBackend( - const std::shared_ptr& group) { + const c10::intrusive_ptr& group) { TORCH_CHECK(!rankNotInGroup(group), "Invalid process group specified"); auto it = pg_map_.find(group); @@ -146,10 +147,10 @@ std::string DistributedC10d::getBackend( return it->second.first; } -std::shared_ptr DistributedC10d::isend( +c10::intrusive_ptr DistributedC10d::isend( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag) { if (rankNotInGroup(group)) { return nullptr; @@ -166,10 +167,10 @@ std::shared_ptr DistributedC10d::isend( return group->send(inputs, group_dst_rank, tag.value_or(0)); } -std::shared_ptr DistributedC10d::irecv( +c10::intrusive_ptr DistributedC10d::irecv( at::Tensor tensor, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag) { if (rankNotInGroup(group)) { return nullptr; @@ -189,7 +190,7 @@ std::shared_ptr DistributedC10d::irecv( void DistributedC10d::send( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag) { auto work = isend(std::move(tensor), dst, group, tag); if (work) { @@ -200,7 +201,7 @@ void DistributedC10d::send( int64_t DistributedC10d::recv( at::Tensor tensor, const c10::optional& src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag) { if (rankNotInGroup(group)) { return -1; @@ -228,10 +229,10 @@ int64_t DistributedC10d::recv( return src.value(); } -std::shared_ptr DistributedC10d::broadcastMultiGPU( +c10::intrusive_ptr DistributedC10d::broadcastMultiGPU( std::vector& tensor_list, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op, int64_t src_tensor) { if (rankNotInGroup(group)) { @@ -243,7 +244,7 @@ std::shared_ptr DistributedC10d::broadcastMultiGPU( opts.rootTensor = src_tensor; checkDefaultPg(); - std::shared_ptr work; + c10::intrusive_ptr work; if (group == default_pg_) { work = default_pg_->broadcast(tensor_list, opts); } else { @@ -259,10 +260,10 @@ std::shared_ptr DistributedC10d::broadcastMultiGPU( return nullptr; } -std::shared_ptr DistributedC10d::broadcast( +c10::intrusive_ptr DistributedC10d::broadcast( at::Tensor tensor, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -273,7 +274,7 @@ std::shared_ptr DistributedC10d::broadcast( opts.rootTensor = 0; std::vector tensors = {std::move(tensor)}; - std::shared_ptr work; + c10::intrusive_ptr work; checkDefaultPg(); if (group == default_pg_) { work = group->broadcast(tensors, opts); @@ -290,9 +291,9 @@ std::shared_ptr DistributedC10d::broadcast( return nullptr; } -std::shared_ptr DistributedC10d::allReduceMultiGPU( +c10::intrusive_ptr DistributedC10d::allReduceMultiGPU( std::vector& tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -313,9 +314,9 @@ std::shared_ptr DistributedC10d::allReduceMultiGPU( return nullptr; } -std::shared_ptr DistributedC10d::allReduce( +c10::intrusive_ptr DistributedC10d::allReduce( at::Tensor tensor, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -337,9 +338,9 @@ std::shared_ptr DistributedC10d::allReduce( return nullptr; } -std::shared_ptr DistributedC10d::allReduceCoalesced( +c10::intrusive_ptr DistributedC10d::allReduceCoalesced( std::vector& tensors, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -360,12 +361,12 @@ std::shared_ptr DistributedC10d::allReduceCoalesced( return nullptr; } -std::shared_ptr DistributedC10d::reduceMultiGPU( +c10::intrusive_ptr DistributedC10d::reduceMultiGPU( std::vector& tensor_list, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, - bool async_op , + bool async_op, int64_t dst_tensor) { if (rankNotInGroup(group)) { return nullptr; @@ -378,7 +379,7 @@ std::shared_ptr DistributedC10d::reduceMultiGPU( checkDefaultPg(); - std::shared_ptr work; + c10::intrusive_ptr work; if (group == default_pg_) { work = group->reduce(tensor_list, opts); } else { @@ -394,10 +395,10 @@ std::shared_ptr DistributedC10d::reduceMultiGPU( return nullptr; } -std::shared_ptr DistributedC10d::reduce( +c10::intrusive_ptr DistributedC10d::reduce( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -409,7 +410,7 @@ std::shared_ptr DistributedC10d::reduce( opts.rootRank = dst; checkDefaultPg(); - std::shared_ptr work; + c10::intrusive_ptr work; std::vector tensors = {std::move(tensor)}; if (group == default_pg_) { work = group->reduce(tensors, opts); @@ -426,10 +427,10 @@ std::shared_ptr DistributedC10d::reduce( return nullptr; } -std::shared_ptr DistributedC10d::allGatherMultiGPU( +c10::intrusive_ptr DistributedC10d::allGatherMultiGPU( std::vector>& output_tensor_lists, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -447,10 +448,10 @@ std::shared_ptr DistributedC10d::allGatherMultiGPU( return nullptr; } -std::shared_ptr DistributedC10d::allGather( +c10::intrusive_ptr DistributedC10d::allGather( std::vector& tensor_list, at::Tensor tensor, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -470,10 +471,10 @@ std::shared_ptr DistributedC10d::allGather( return nullptr; } -std::shared_ptr DistributedC10d::allGatherCoalesced( +c10::intrusive_ptr DistributedC10d::allGatherCoalesced( std::vector>& output_tensor_lists, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -492,10 +493,10 @@ std::shared_ptr DistributedC10d::allGatherCoalesced( return nullptr; } -std::shared_ptr DistributedC10d::gather( +c10::intrusive_ptr DistributedC10d::gather( at::Tensor tensor, const c10::optional>& gather_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, int64_t dst, bool async_op) { if (rankNotInGroup(group)) { @@ -522,7 +523,7 @@ std::shared_ptr DistributedC10d::gather( GatherOptions opts; opts.rootRank = dst; - std::shared_ptr work; + c10::intrusive_ptr work; if (group == default_pg_) { work = group->gather(output_tensors, input_tensors, opts); } else { @@ -538,10 +539,10 @@ std::shared_ptr DistributedC10d::gather( return nullptr; } -std::shared_ptr DistributedC10d::scatter( +c10::intrusive_ptr DistributedC10d::scatter( at::Tensor tensor, std::vector& scatter_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, int64_t src, bool async_op) { if (rankNotInGroup(group)) { @@ -559,7 +560,7 @@ std::shared_ptr DistributedC10d::scatter( ScatterOptions opts; opts.rootRank = src; - std::shared_ptr work; + c10::intrusive_ptr work; if (group == default_pg_) { work = group->scatter(output_tensors, input_tensors, opts); } else { @@ -575,10 +576,10 @@ std::shared_ptr DistributedC10d::scatter( return nullptr; } -std::shared_ptr DistributedC10d::reduceScatterMultiGPU( +c10::intrusive_ptr DistributedC10d::reduceScatterMultiGPU( std::vector& output_tensor_list, std::vector>& input_tensor_lists, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -598,10 +599,10 @@ std::shared_ptr DistributedC10d::reduceScatterMultiGPU( return nullptr; } -std::shared_ptr DistributedC10d::reduceScatter( +c10::intrusive_ptr DistributedC10d::reduceScatter( at::Tensor output, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op, bool async_op) { if (rankNotInGroup(group)) { @@ -624,12 +625,12 @@ std::shared_ptr DistributedC10d::reduceScatter( return nullptr; } -std::shared_ptr DistributedC10d::allToAllSingle( +c10::intrusive_ptr DistributedC10d::allToAllSingle( at::Tensor output, at::Tensor input, std::vector& output_split_sizes, std::vector& input_split_sizes, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -646,10 +647,10 @@ std::shared_ptr DistributedC10d::allToAllSingle( return nullptr; } -std::shared_ptr DistributedC10d::allToAll( +c10::intrusive_ptr DistributedC10d::allToAll( std::vector& output_tensor_list, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; @@ -665,8 +666,8 @@ std::shared_ptr DistributedC10d::allToAll( return nullptr; } -std::shared_ptr DistributedC10d::barrier( - const std::shared_ptr& group, +c10::intrusive_ptr DistributedC10d::barrier( + const c10::intrusive_ptr& group, bool async_op) { if (rankNotInGroup(group)) { return nullptr; diff --git a/torch/lib/c10d/frontend.hpp b/torch/lib/c10d/frontend.hpp index 3449ee30b5ef..328bd31dbf36 100644 --- a/torch/lib/c10d/frontend.hpp +++ b/torch/lib/c10d/frontend.hpp @@ -38,184 +38,184 @@ class DistributedC10d { c10::intrusive_ptr store, const std::string& group_name); - void destroyProcessGroup(std::shared_ptr group); - int64_t getRank(const std::shared_ptr& group) const; - int64_t getWorldSize(const std::shared_ptr& group) const; + void destroyProcessGroup(c10::intrusive_ptr group); + int64_t getRank(const c10::intrusive_ptr& group) const; + int64_t getWorldSize(const c10::intrusive_ptr& group) const; - std::shared_ptr isend( + c10::intrusive_ptr isend( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag); - std::shared_ptr irecv( + c10::intrusive_ptr irecv( at::Tensor tensor, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag); void send( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag); int64_t recv( at::Tensor tensor, const c10::optional& src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, c10::optional& tag); - std::shared_ptr broadcastMultiGPU( + c10::intrusive_ptr broadcastMultiGPU( std::vector& tensor_list, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false, int64_t src_tensor = 0); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor tensor, int64_t src, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr allReduceMultiGPU( + c10::intrusive_ptr allReduceMultiGPU( std::vector& tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr allReduce( + c10::intrusive_ptr allReduce( at::Tensor tensor, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr allReduceCoalesced( + c10::intrusive_ptr allReduceCoalesced( std::vector& tensors, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr reduceMultiGPU( + c10::intrusive_ptr reduceMultiGPU( std::vector& tensor_list, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false, int64_t dst_tensor = 0); - std::shared_ptr reduce( + c10::intrusive_ptr reduce( at::Tensor tensor, int64_t dst, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr allGatherMultiGPU( + c10::intrusive_ptr allGatherMultiGPU( std::vector>& output_tensor_lists, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr allGather( + c10::intrusive_ptr allGather( std::vector& tensor_list, at::Tensor tensor, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr allGatherCoalesced( + c10::intrusive_ptr allGatherCoalesced( std::vector>& output_tensor_lists, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr gather( + c10::intrusive_ptr gather( at::Tensor tensor, const c10::optional>& gather_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, int64_t dst = 0, bool async_op = false); - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor tensor, std::vector& scatter_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, int64_t src = 0, bool async_op = false); - std::shared_ptr reduceScatterMultiGPU( + c10::intrusive_ptr reduceScatterMultiGPU( std::vector& output_tensor_list, std::vector>& input_tensor_lists, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr reduceScatter( + c10::intrusive_ptr reduceScatter( at::Tensor output, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, ReduceOp op = ReduceOp::SUM, bool async_op = false); - std::shared_ptr allToAllSingle( + c10::intrusive_ptr allToAllSingle( at::Tensor output, at::Tensor input, std::vector& output_split_sizes, std::vector& input_split_sizes, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr allToAll( + c10::intrusive_ptr allToAll( std::vector& output_tensor_list, std::vector& input_tensor_list, - const std::shared_ptr& group, + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr barrier( - const std::shared_ptr& group, + c10::intrusive_ptr barrier( + const c10::intrusive_ptr& group, bool async_op = false); - std::shared_ptr newGroup( + c10::intrusive_ptr newGroup( std::vector ranks, std::chrono::milliseconds timeout, Backend backend); - std::shared_ptr worldProcessGroup(); + c10::intrusive_ptr worldProcessGroup(); private: DistributedC10d(){}; - bool rankNotInGroup(const std::shared_ptr& group) const; + bool rankNotInGroup(const c10::intrusive_ptr& group) const; int64_t getGroupRank( - const std::shared_ptr& group, + const c10::intrusive_ptr& group, const int64_t rank) const; int64_t getGlobalRank( - const std::shared_ptr& group, + const c10::intrusive_ptr& group, const int64_t group_rank) const; void checkDefaultPg() const; - int64_t getGroupSize(const std::shared_ptr& group) const; - std::string getBackend(const std::shared_ptr& group); + int64_t getGroupSize(const c10::intrusive_ptr& group) const; + std::string getBackend(const c10::intrusive_ptr& group); std::string backend_; // TODO: Ask Alex what kind of equality we need. It determine whether we // need to use ProcessGroup or ProcesGroup* as key. std::unordered_map< - std::shared_ptr, + c10::intrusive_ptr, std::pair>> pg_map_; // Note, this is different mapping relationship than original Python // implementation. - std::unordered_map, std::string> pg_names_; + std::unordered_map, std::string> pg_names_; // Process group's global rank to local rank mapping std::unordered_map< - std::shared_ptr, + c10::intrusive_ptr, std::unordered_map> pg_group_ranks_; - std::shared_ptr default_pg_; + c10::intrusive_ptr default_pg_; // Default value should be "env://" std::string default_pg_init_method_; diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index c5ee54a9ee8e..d0edd904ca94 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -29,7 +29,7 @@ constexpr int kUnsetDivFactor = -1; Reducer::Reducer( std::vector> replicas, std::vector> bucket_indices, - std::shared_ptr process_group, + c10::intrusive_ptr process_group, std::vector> expect_sparse_gradients, int64_t bucket_bytes_cap, bool find_unused_parameters, diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index e0fe0004f88e..ada39844a9ca 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -29,7 +29,7 @@ class Reducer { explicit Reducer( std::vector> replicas, std::vector> bucket_indices, - std::shared_ptr process_group, + c10::intrusive_ptr process_group, std::vector> expect_sparse_gradients, int64_t bucket_bytes_cap, bool find_unused_parameters, @@ -125,7 +125,7 @@ class Reducer { mutable std::mutex mutex_; std::vector> replicas_; - std::shared_ptr process_group_; + c10::intrusive_ptr<::c10d::ProcessGroup> process_group_; std::vector> expect_sparse_gradients_; std::vector>> diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp index 6c60b3d6742d..5503b4cde866 100644 --- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp +++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp @@ -13,7 +13,7 @@ // Wait for work to complete void waitWork( - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroupMPI> pg, std::vector> works) { for (auto& work : works) { try { diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 82ca25049c63..3dbd26655391 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10::intrusive_ptr opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10::intrusive_ptr opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} From e1ee3bfc0e50e093d95c5f0b69aa641e007a2991 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Thu, 12 Nov 2020 07:55:33 -0800 Subject: [PATCH 50/93] Port bmm and baddbmm from TH to ATen (#42553) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42553 Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24893511 Pulled By: anjali411 fbshipit-source-id: 0eba3f2aec99c48b3018a5264ee7789279cfab58 --- BUILD.bazel | 1 - aten/src/ATen/LegacyTHFunctionsCUDA.h | 4 - aten/src/ATen/NamedTensorUtils.cpp | 17 +- aten/src/ATen/NamedTensorUtils.h | 8 +- aten/src/ATen/core/aten_interned_strings.h | 3 - aten/src/ATen/cuda/CUDABlas.cpp | 205 +++++++++++ aten/src/ATen/cuda/CUDABlas.h | 37 +- aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp | 329 ----------------- aten/src/ATen/native/cuda/LinearAlgebra.cu | 182 ++++++++-- aten/src/THC/CMakeLists.txt | 3 - aten/src/THC/THCBlas.cu | 359 ------------------- aten/src/THC/THCBlas.h | 25 -- aten/src/THC/THCTensorMath.h | 6 - aten/src/THC/THCTensorMathBlas.cu | 13 - aten/src/THC/generic/THCTensorMathBlas.cu | 326 ----------------- aten/src/THC/generic/THCTensorMathBlas.h | 7 - test/test_torch.py | 16 +- 17 files changed, 408 insertions(+), 1133 deletions(-) delete mode 100644 aten/src/THC/THCTensorMathBlas.cu delete mode 100644 aten/src/THC/generic/THCTensorMathBlas.cu delete mode 100644 aten/src/THC/generic/THCTensorMathBlas.h diff --git a/BUILD.bazel b/BUILD.bazel index 4ec99d770f70..7dc0e6d213fb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -378,7 +378,6 @@ filegroup( "aten/src/THC/THCTensorCopy.cu.cc", "aten/src/THC/THCTensorIndex.cu.cc", "aten/src/THC/THCTensorMath.cu.cc", - "aten/src/THC/THCTensorMathBlas.cu.cc", "aten/src/THC/THCTensorMathMagma.cu.cc", "aten/src/THC/THCTensorMathPairwise.cu.cc", "aten/src/THC/THCTensorMathReduce.cu.cc", diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index 7b3be6db3d77..1ec33b675cbf 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -44,10 +44,6 @@ Tensor & _th_fmod_(Tensor & self, Scalar other); Tensor & _th_fmod_(Tensor & self, const Tensor & other); Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim); Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim); -Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2); -Tensor _th_bmm(const Tensor & self, const Tensor & mat2); -Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha); -Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha); std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A); std::tuple _th_gels(const Tensor & self, const Tensor & A); std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors); diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index f59cbed39abb..668838877123 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -517,17 +517,16 @@ std::vector compute_bmm_outnames( } std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* batch1, - TensorImpl* batch2, - TensorImpl* bias) { - if (!impl::has_names(result) && !impl::has_names(batch1) && - !impl::has_names(batch2) && !impl::has_names(bias)) { + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias) { + if (!result.has_names() && !self.has_names() + && !other.has_names() && !bias.has_names()) { return {}; } - auto bmm_names = compute_matmul_outnames( - impl::get_names(batch1), impl::get_names(batch2)); - auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names); + auto bmm_names = compute_matmul_outnames(self.names(), other.names()); + auto baddbmm_names = unify_from_right(bias.names(), bmm_names); return baddbmm_names; } diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index 6777f39f7fcf..47dfd580a189 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv( CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); CAFFE2_API std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* self, - TensorImpl* other, - TensorImpl* bias); + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias); CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 4a1aa4e9f0d2..267140f5d90c 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -133,8 +133,6 @@ _(aten, _sum_cuda) \ _(aten, _tan) \ _(aten, _tanh) \ _(aten, _tanh_forward) \ -_(aten, _th_baddbmm) \ -_(aten, _th_bmm) \ _(aten, _th_get_device) \ _(aten, _th_kthvalue) \ _(aten, _th_mode) \ @@ -669,7 +667,6 @@ _(aten, tanh) \ _(aten, tensor) \ _(aten, tensordot) \ _(aten, tensor_split) \ -_(aten, th_addmm) \ _(aten, th_clone) \ _(aten, th_norm) \ _(aten, th_pow) \ diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d4b31401f31f..8c32c8db1a1c 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) { /* LEVEL 3 BLAS FUNCTIONS */ +#ifndef __HIP_PLATFORM_HCC__ +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200 +#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx +#else +// Workaround for https://github.com/pytorch/pytorch/issues/45724 +cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType Atype, + int lda, + long long int strideA, + const void *B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void *beta, + void *C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int64_t batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo) +{ + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major != 7) { + return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo); + } + cublasStatus_t result; + constexpr int64_t split = 63 * 1024; + for(int64_t i = 0; i < batchCount; i += split) { + int64_t count = std::min(split, batchCount - i); + result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, + (char *)A + i * strideA * 2, Atype, lda, strideA, + (char *)B + i * strideB * 2, Btype, ldb, strideB, + beta, + (char *)C + i * strideC * 2, Ctype, ldc, strideC, + (int)count, computeType, algo); + TORCH_CUDABLAS_CHECK(result); + } + return result; +} +#endif +#endif + #define GEMM_CHECK_ARGVALUES(Dtype) \ do { \ CUDABLAS_NONNEGINT_CHECK(gemm, m); \ @@ -143,6 +193,161 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) { CUDABLAS_POSINT_CHECK(gemm, ldc); \ } while (0) +#define BGEMM_CHECK_ARGVALUES(Dtype) \ + do { \ + CUDABLAS_NONNEGINT_CHECK(bgemm, m); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, n); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, k); \ + CUDABLAS_POSINT_CHECK(bgemm, lda); \ + CUDABLAS_POSINT_CHECK(bgemm, ldb); \ + CUDABLAS_POSINT_CHECK(bgemm, ldc); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, num_batches); \ + } while (0) + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(double); + TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(float); + TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(c10::complex); + TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), + lda, stridea, reinterpret_cast(b), ldb, strideb, reinterpret_cast(&beta), + reinterpret_cast(c), ldc, stridec, num_batches)); +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(c10::complex); + TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), + lda, stridea, reinterpret_cast(b), ldb, strideb, reinterpret_cast(&beta), + reinterpret_cast(c), ldc, stridec, num_batches)); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(at::Half); + float falpha = alpha; + float fbeta = beta; +#ifdef __HIP_PLATFORM_HCC__ + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea, + b, rocblas_datatype_f16_r, (int)ldb, strideb, + (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec, + c, rocblas_datatype_f16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0)); +#else + #if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + #endif // CUDA_VERSION < 11000 + + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 5){ + TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix( + handle, opa, opb, m, n, k, + (void*)(&falpha), a, CUDA_R_16F, lda, stridea, + b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), + c, CUDA_R_16F, ldc, stridec, + num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + for (int64_t i = 0; i < num_batches; ++i) { + at::cuda::blas::gemm( + transa, transb, + m, n, k, + alpha, (a + i * stridea), lda, + (b + i * strideb), ldb, beta, + (c + i * stridec), ldc); + } + } + #if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + #endif // CUDA_VERSION < 11000 +#endif // __HIP_PLATFORM_HCC__ +} + +#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + BGEMM_CHECK_ARGVALUES(at::BFloat16); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + float falpha = alpha; + float fbeta = beta; + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); + TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle, + opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, + b, CUDA_R_16BF, (int)ldb, strideb, + (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, + (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + #elif defined(__HIP_PLATFORM_HCC__) + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea, + b, rocblas_datatype_bf16_r, (int)ldb, strideb, + (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec, + c, rocblas_datatype_bf16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0, NULL, NULL)); + #else + TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); + #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +} +#endif // __HIP_PLATFORM_HCC__ + template <> void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index c5b4c43a27b1..93a0ff588dda 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -69,6 +69,31 @@ template <> void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); #endif +#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \ + const Dtype *a, int64_t lda, int64_t stridea, \ + const Dtype *b, int64_t ldb, int64_t strideb, \ + Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches + +template +inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +#endif /* LEVEL 2 BLAS FUNCTIONS */ #define CUDABLAS_GEMV_ARGTYPES(Dtype) \ @@ -97,18 +122,6 @@ template <> void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); #endif -template -void ger( - int64_t m, - int64_t n, - Dtype alpha, - Dtype* x, - int64_t incx, - Dtype* y, - int64_t incy, - Dtype* a, - int64_t lda); - /* LEVEL 1 BLAS FUNCTIONS */ #define CUDABLAS_DOT_ARGTYPES(Dtype) \ diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index 45ceddcd94e8..0aad275684a6 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -1536,336 +1536,7 @@ Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim) } return result; } -Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, uint8_t(0), uint8_t(1)); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int8_t(0), int8_t(1)); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, double(0), double(1)); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, float(0), float(1)); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int(0), int(1)); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int64_t(0), int64_t(1)); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int16_t(0), int16_t(1)); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, Half(0), Half(1)); - break; - } - case ScalarType::BFloat16: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, BFloat16(0), BFloat16(1)); - break; - } - default: - AT_ERROR("_th_bmm_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_bmm(const Tensor & self, const Tensor & mat2) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, uint8_t(0), uint8_t(1)); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int8_t(0), int8_t(1)); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, double(0), double(1)); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, float(0), float(1)); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int(0), int(1)); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int64_t(0), int64_t(1)); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int16_t(0), int16_t(1)); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, Half(0), Half(1)); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, BFloat16(0), BFloat16(1)); - break; - } - default: - AT_ERROR("_th_bmm not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toByte(); - auto alpha_ = alpha.toByte(); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toChar(); - auto alpha_ = alpha.toChar(); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toDouble(); - auto alpha_ = alpha.toDouble(); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toFloat(); - auto alpha_ = alpha.toFloat(); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toInt(); - auto alpha_ = alpha.toInt(); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toLong(); - auto alpha_ = alpha.toLong(); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toShort(); - auto alpha_ = alpha.toShort(); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toHalf(); - auto alpha_ = alpha.toHalf(); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::BFloat16: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toBFloat16(); - auto alpha_ = alpha.toBFloat16(); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - default: - AT_ERROR("_th_baddbmm_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toByte(); - auto alpha_ = alpha.toByte(); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toChar(); - auto alpha_ = alpha.toChar(); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toDouble(); - auto alpha_ = alpha.toDouble(); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toFloat(); - auto alpha_ = alpha.toFloat(); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toInt(); - auto alpha_ = alpha.toInt(); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toLong(); - auto alpha_ = alpha.toLong(); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toShort(); - auto alpha_ = alpha.toShort(); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toHalf(); - auto alpha_ = alpha.toHalf(); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toBFloat16(); - auto alpha_ = alpha.toBFloat16(); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - default: - AT_ERROR("_th_baddbmm not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 3bb9cea5e5cc..95998790d093 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -5,32 +5,6 @@ namespace at { namespace native { -Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); - return legacy::cuda::_th_baddbmm(b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm_out"); - return legacy::cuda::_th_baddbmm_out(result, b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - return baddbmm_out_cuda(self, self, batch1, batch2, beta, alpha); -} - -Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) { - result.resize_({ batch1.size(0), batch1.size(1), batch2.size(2) }); - return legacy::cuda::_th_bmm_out(result, batch1, batch2); -} - -Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) { - Tensor result = at::empty({0}, self.options()); - return native::bmm_out_cuda(result, self, mat2); -} - Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { Tensor tensor_; IntArrayRef tensor_strides = tensor.strides(); @@ -50,6 +24,35 @@ Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { return tensor_; } +Tensor prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) { + IntArrayRef tensor_strides = tensor.strides(); + Tensor tensor_; + int fast_dim = transpose_result ? 2 : 1; + int leading_dim = transpose_result ? 1 : 2; + + if (tensor_strides[fast_dim] == 1 && + (tensor_strides[leading_dim] >= std::max(1, m))) { + transpose_tensor = false; + tensor_ = tensor; + ld_tensor = tensor_strides[leading_dim]; + } else if ((tensor_strides[leading_dim] == 1) && + (tensor_strides[fast_dim] >= std::max(1, n))) { + transpose_tensor = true; + tensor_ = tensor; + ld_tensor = tensor_strides[fast_dim]; + } else { + transpose_tensor = !transpose_result; + if (tensor.is_contiguous()) { + tensor_ = tensor; + } else { + tensor_ = tensor.clone(at::MemoryFormat::Contiguous); + } + ld_tensor = tensor_.stride(1); + } + + return tensor_; +} + namespace { Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { @@ -142,6 +145,88 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma return result; } +Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor"); + TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); + + TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {batch1, "batch1", 2}, {batch2, "batch2", 3}}; + checkAllSameGPU("baddbmm", args); + + IntArrayRef batch1_sizes = batch1.sizes(); + IntArrayRef batch2_sizes = batch2.sizes(); + IntArrayRef self_sizes = self.sizes(); + + TORCH_CHECK(self_sizes[0] == batch1_sizes[0], "self dim 0 must match batch1 dim 0"); + TORCH_CHECK(self_sizes[0] == batch2_sizes[0], "self dim 0 must match batch2 dim 0"); + TORCH_CHECK(self_sizes[1] == batch1_sizes[1], "self dim 1 must match batch1 dim 1"); + TORCH_CHECK(self_sizes[2] == batch2_sizes[2], "self dim 2 must match batch2 dim 2"); + TORCH_CHECK(batch1_sizes[2] == batch2_sizes[1], "batch1 dim 2 must match batch2 dim 1"); + + if (!result.is_same(self)) { + result.resize_as_(self); + if (beta.to>() != 0.0) { + result.copy_(self); + } + } + + bool transpose_result = false; + Tensor result_; + IntArrayRef result_strides = result.strides(); + IntArrayRef result_sizes = result.sizes(); + + if ((result_strides[1] == 1) && + ((result_sizes[2] == 1) || (result_strides[2] >= std::max(1, result_sizes[1])))) { + result_ = result; + } else if ((result_strides[2] == 1) && + (result_sizes[1] == 1 || (result_strides[1] >= std::max(1, result_sizes[2])))) { + transpose_result = true; + result_ = result; + } else { + result_ = result.transpose(1, 2).clone(at::MemoryFormat::Contiguous); + result_ = result_.transpose(1, 2); + } + + int leading_dim = transpose_result ? 1 : 2; + + Tensor batch1_ = transpose_result ? batch2 : batch1; + Tensor batch2_ = transpose_result ? batch1 : batch2; + int64_t m = result_sizes[transpose_result ? 2 : 1]; + int64_t n = result_sizes[leading_dim]; + int64_t k = batch1_.size(leading_dim); + + int64_t lda, ldb, ldc; + bool transpose_batch1, transpose_batch2; + batch1_ = prepare_batch_matrix_for_cublas(batch1_, transpose_batch1, lda, transpose_result, m, k); + batch2_ = prepare_batch_matrix_for_cublas(batch2_, transpose_batch2, ldb, transpose_result, k, n); + + ldc = result_.stride(leading_dim); + int64_t num_batches = result_.size(0); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] { + scalar_t alpha_val = alpha.to(); + scalar_t beta_val = beta.to(); + scalar_t* batch1_ptr = batch1_.data_ptr(); + scalar_t* batch2_ptr = batch2_.data_ptr(); + scalar_t* result_ptr = result_.data_ptr(); + at::cuda::blas::bgemm( + transpose_batch1 ? 't' : 'n', + transpose_batch2 ? 't' : 'n', + m, n, k, + alpha_val, + batch1_ptr, lda, batch1_.stride(0), + batch2_ptr, ldb, batch2_.stride(0), + beta_val, + result_ptr, ldc, result_.stride(0), + num_batches + ); + }); + if (!result.is_same(result_)) { + result.copy_(result_); + } + return result; +} + } // anonymous namespace Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) { @@ -178,6 +263,51 @@ Tensor& addmm__cuda(Tensor& self, const Tensor& mat1, const Tensor& mat2, return self; } +Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + Tensor self_; + if (&result != &self) { + std::tie(self_) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); + } else { + self_ = self; + } + { + at::NoNamesGuard guard; + baddbmm_out_cuda_impl(result, self_, batch1, batch2, beta, alpha); + } + namedinference::propagate_names_if_nonempty( + result, + namedinference::compute_baddbmm_outnames(result, batch1, batch2, self)); + return result; +} + +Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + Tensor out = at::empty({0}, self.options()); + return baddbmm_out_cuda(out, self, batch1, batch2, beta, alpha); +} + +Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + return baddbmm_out_cuda(self, self, batch1, batch2, beta, alpha); +} + +Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) { + result.resize_({ batch1.size(0), batch1.size(1), batch2.size(2) }); + Scalar beta(0.0); + Scalar alpha(1.0); + { + NoNamesGuard guard; + baddbmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha); + } + namedinference::propagate_names_if_nonempty( + result, + namedinference::compute_bmm_outnames(result, batch1, batch2)); + return result; +} + +Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) { + Tensor result = at::empty({0}, self.options()); + return native::bmm_out_cuda(result, self, mat2); +} + Tensor& addbmm_out_cuda(Tensor& out, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index bee2f5b84e50..4ba4a4ce4456 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -48,7 +48,6 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorCopy.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMath.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathBlas.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathMagma.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathPairwise.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathReduce.cu @@ -141,8 +140,6 @@ install(FILES generic/THCTensorMasked.cu generic/THCTensorMath.h generic/THCTensorMath.cu - generic/THCTensorMathBlas.cu - generic/THCTensorMathBlas.h generic/THCTensorMathMagma.h generic/THCTensorMathMagma.cu generic/THCTensorMathPairwise.h diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu index 3f16eec6df60..99ee29d18766 100644 --- a/aten/src/THC/THCBlas.cu +++ b/aten/src/THC/THCBlas.cu @@ -11,113 +11,12 @@ #include #endif -/* Level 2 */ - -void adjustLdLevel2(int64_t m, int64_t n, int64_t *lda) -{ - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). - // TODO: why does Level3 check trans but this doesn't? - if (n <= 1) - *lda = std::max(m, 1); -} - -void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Sger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - -void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Dger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - - -cublasOperation_t convertTransToCublasOperation(char trans) { - if (trans == 't') return CUBLAS_OP_T; - else if (trans == 'n') return CUBLAS_OP_N; - else if (trans == 'c') return CUBLAS_OP_C; - else { - THError("trans must be one of: t, n, c"); - return CUBLAS_OP_T; - } -} - -void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) -{ - int transa_ = ((transa == 't') || (transa == 'T')); - int transb_ = ((transb == 't') || (transb == 'T')); - - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). - if(n <= 1) - *ldc = std::max(m, 1); - - if(transa_) - { - if(m <= 1) - *lda = std::max(k, 1); - } - else - { - if(k <= 1) - *lda = std::max(m, 1); - } - - if(transb_) - { - if(k <= 1) - *ldb = std::max(n, 1); - } - else - { - if(n <= 1) - *ldb = std::max(k, 1); - } - -} - /* Level 3 */ void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc) { at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -// In CUDA 8.0, definition of data types for sgemmex changed -#if CUDA_VERSION < 8000 -# define CUDA_R_16F CUBLAS_DATA_HALF -#endif - void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::Half alpha, at::Half *a, int64_t lda, at::Half *b, int64_t ldb, at::Half beta, at::Half *c, int64_t ldc) { at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -132,261 +31,3 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6 { at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - -#ifndef __HIP_PLATFORM_HCC__ -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200 -#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx -#else -// Workaround for https://github.com/pytorch/pytorch/issues/45724 -cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, - const void *A, - cudaDataType Atype, - int lda, - long long int strideA, - const void *B, - cudaDataType Btype, - int ldb, - long long int strideB, - const void *beta, - void *C, - cudaDataType Ctype, - int ldc, - long long int strideC, - int64_t batchCount, - cudaDataType computeType, - cublasGemmAlgo_t algo) -{ - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major != 7) { - return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo); - } - cublasStatus_t result; - constexpr int64_t split = 63 * 1024; - for(int64_t i = 0; i < batchCount; i += split) { - int64_t count = std::min(split, batchCount - i); - result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, - (char *)A + i * strideA * 2, Atype, lda, strideA, - (char *)B + i * strideB * 2, Btype, ldb, strideB, - beta, - (char *)C + i * strideC * 2, Ctype, ldc, strideC, - (int)count, computeType, algo); - THCublasCheck(result); - } - return result; -} -#endif -#endif - -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB, - at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; -#ifdef __HIP_PLATFORM_HCC__ - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_f16_r, (int)lda, strideA, - b, rocblas_datatype_f16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, strideC, - c, rocblas_datatype_f16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0)); -#else -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); -#endif // CUDA_VERSION < 11000 - THCublasCheck(cublasGemmStridedBatchedExFix(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, - b, CUDA_R_16F, (int)ldb, strideB, - (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, - (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); -#endif // CUDA_VERSION < 11000 -#endif // __HIP_PLATFORM_HCC__ -} - -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; - -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major < 8) { - TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU"); - } - THCublasCheck(cublasGemmStridedBatchedExFix(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA, - b, CUDA_R_16BF, (int)ldb, strideB, - (void*)&fBeta, c, CUDA_R_16BF, (int)ldc, strideC, - (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -#elif defined(__HIP_PLATFORM_HCC__) - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_bf16_r, (int)lda, strideA, - b, rocblas_datatype_bf16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_bf16_r, (int)ldc, strideC, - c, rocblas_datatype_bf16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0, NULL, NULL)); -#else - TORCH_CHECK(false, "THCudaBlas_BgemmStridedBatched is only available on CUDA_VERSION >= 11"); -#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -} - -void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} - -void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], int64_t ldc, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} diff --git a/aten/src/THC/THCBlas.h b/aten/src/THC/THCBlas.h index 4078363eb888..7d537da28be3 100644 --- a/aten/src/THC/THCBlas.h +++ b/aten/src/THC/THCBlas.h @@ -5,10 +5,6 @@ #include #include -/* Level 2 */ -THC_API void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda); -THC_API void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda); - /* Level 3 */ THC_API void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc); THC_API void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, double alpha, double *a, int64_t lda, double *b, int64_t ldb, double beta, double *c, int64_t ldc); @@ -17,25 +13,4 @@ THC_API void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t THC_API void THCudaBlas_Bgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::BFloat16 alpha, at::BFloat16 *a, int64_t lda, at::BFloat16 *b, int64_t ldb, at::BFloat16 beta, at::BFloat16 *c, int64_t ldc); -THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount); -THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], int64_t ldc, int64_t batchCount); -THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount); -THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount); - -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - THHalf alpha, const THHalf *a, int64_t lda, int64_t strideA, const THHalf *b, int64_t ldb, int64_t strideB, - THHalf beta, THHalf *c, int64_t ldc, int64_t strideC, int64_t batchCount); - -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount); - #endif diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 68fbb240afb4..fd316f93ed55 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -13,12 +13,6 @@ #include #include -#include -#include - -#include -#include - #include #include diff --git a/aten/src/THC/THCTensorMathBlas.cu b/aten/src/THC/THCTensorMathBlas.cu deleted file mode 100644 index 383d1ed17b1d..000000000000 --- a/aten/src/THC/THCTensorMathBlas.cu +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu deleted file mode 100644 index a5d159a9cace..000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ /dev/null @@ -1,326 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.cu" -#else - -#include -#include - -#define ERROR_ONLY_FP_TYPES(func) \ - THError("%s for CUDA tensors only supports floating-point types. Try converting the tensors with .float()", func); - -__global__ void createBatchGemmBuffer3(const scalar_t** buffer1, const scalar_t ** buffer2, const scalar_t ** buffer3, scalar_t* data1, - scalar_t * data2, scalar_t * data3, int64_t stride1, int64_t stride2, int64_t stride3, int64_t num_batches) { - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < num_batches) { - buffer1[idx] = data1 + idx * stride1; - buffer2[idx] = data2 + idx * stride2; - buffer3[idx] = data3 + idx * stride3; - } -} - -void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, - THCTensor *batch1, THCTensor *batch2, - scalar_t beta, scalar_t alpha) { -#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_BFLOAT16) - THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2)); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 3, 4, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch1) == 3, 6, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch2) == 3, 7, "expected 3D tensor"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch1, 0), 6, - "equal number of batches expected"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch2, 0), 7, - "equal number of batches expected"); - auto maybe_outnames = at::namedinference::compute_baddbmm_outnames(result, batch1, batch2, t); - { - at::NoNamesGuard guard; - THArgCheck(THCTensor_(size)(state, t, 1) == THCTensor_(size)(state, batch1, 1), 6, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, t, 2) == THCTensor_(size)(state, batch2, 2), 7, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, batch1, 2) == THCTensor_(size)(state, batch2, 1), 6, - "wrong matrix size"); - - if (t != result) { - THCTensor_(resizeAs)(state, result, t); - if (ScalarConvert::to(beta) != 0.0) { - THCTensor_(copy)(state, result, t); - } - } - - bool transpose_result; - char transpose_batch1, transpose_batch2; - int64_t lda, ldb, ldc; - THCTensor *result_, *batch1_, *batch2_; - if (result->stride(1) == 1 && - (result->size(2) == 1 || result->stride(2) >= std::max(1, result->size(1)))) - { - transpose_result = false; - result_ = result; - ldc = result_->stride(2); - } - else if (result->stride(2) == 1 && - (result->size(1) == 1 || result->stride(1) >= std::max(1, result->size(2)))) - { - transpose_result = true; - - THCTensor *swap = batch2; - batch2 = batch1; - batch1 = swap; - - result_ = result; - ldc = result_->stride(1); - } - else - { - transpose_result = false; - - THCTensor *transp_r_ = THCTensor_(newTranspose)(state, result, 1, 2); - result_ = THCTensor_(newClone)(state, transp_r_); - THCTensor_(free)(state, transp_r_); - THCTensor_(transpose)(state, result_, NULL, 1, 2); - - ldc = result_->stride(2); - } - - const int64_t m = result->size(transpose_result ? 2 : 1); - const int64_t n = result->size(transpose_result ? 1 : 2); - const int64_t k = batch1->size(transpose_result ? 1 : 2); - - if (batch1->stride(transpose_result ? 2 : 1) == 1 && - batch1->stride(transpose_result ? 1 : 2) >= std::max(1, m)) - { - transpose_batch1 = 'n'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 1 : 2); - } - else if (batch1->stride(transpose_result ? 1 : 2) == 1 && - batch1->stride(transpose_result ? 2 : 1) >= std::max(1, k)) - { - transpose_batch1 = 't'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 2 : 1); - } - else - { - transpose_batch1 = transpose_result ? 'n' : 't'; - // batch1_ is later freed if batch1_ != batch1 - if (THCTensor_(isContiguous)(state, batch1)) { - batch1_ = batch1; - } else { - batch1_ = THCTensor_(newContiguous)(state, batch1); - } - lda = batch1_->stride(1); - } - - if (batch2->stride(transpose_result ? 2 : 1) == 1 && - batch2->stride(transpose_result ? 1 : 2) >= std::max(1, k)) - { - transpose_batch2 = 'n'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 1 : 2); - } - else if (batch2->stride(transpose_result ? 1 : 2) == 1 && - batch2->stride(transpose_result ? 2 : 1) >= std::max(1, n)) - { - transpose_batch2 = 't'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 2 : 1); - } - else - { - transpose_batch2 = transpose_result ? 'n' : 't'; - // batch2_ is later freed if batch2_ != batch2 - if (THCTensor_(isContiguous)(state, batch2)) { - batch2_ = batch2; - } else { - batch2_ = THCTensor_(newContiguous)(state, batch2); - } - ldb = batch2_->stride(1); - } - int64_t num_batches = result_->size(0); - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - // Compute pointers to matrices in each batch. -#if CUDA_VERSION < 8000 && !defined __HIP_PLATFORM_HCC__ - size_t matrices_size = num_batches * sizeof(scalar_t*); - -// Copy pointers to device. - auto d_matrices1 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_matrices2 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_result_matrices = static_cast(THCudaMalloc(state, matrices_size)); - - const int64_t block = 512; - const int64_t grid = (num_batches + block - 1) / block; - - createBatchGemmBuffer3<<>>( - d_matrices1, d_matrices2, (const scalar_t**)d_result_matrices, THCTensor_(data)(state, batch1_), - THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_), - batch1_->stride(0), batch2_->stride(0), result_->stride(0), num_batches); - -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#endif //THC_REAL - - THCudaFree(state, d_matrices1); - THCudaFree(state, d_matrices2); - THCudaFree(state, d_result_matrices); - -#else -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif //THC_REAL -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_HALF) - -#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__) - // Currently no HgemmBatched in Cublas - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } -#else -#ifndef __HIP_PLATFORM_HCC__ - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major >= 5){ -#endif - - THCudaBlas_HgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#ifndef __HIP_PLATFORM_HCC__ - } else { - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } - } -#endif -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_BFLOAT16) -#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - THCudaBlas_BgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif // __HIP_PLATFORM_HCC__ -#endif - - if (batch1_ != batch1) { - THCTensor_(free)(state, batch1_); - } - - if (batch2_ != batch2) { - THCTensor_(free)(state, batch2_); - } - - if (result_ != result) { - THCTensor_(freeCopyTo)(state, result_, result); - } - -#if defined(THC_REAL_IS_BFLOAT16) && !(defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000) - // To avoid "variable was set but never used" warning - [&transpose_batch1, &transpose_batch2, &lda, &ldb, &ldc]{}(); - TORCH_CHECK(false, "BgemmStridedBatched is not supported with at::BFloat16 type"); -#endif - } - at::namedinference::propagate_names_if_nonempty(result, maybe_outnames); - -#else - ERROR_ONLY_FP_TYPES("baddbmm"); -#endif -} - -#endif diff --git a/aten/src/THC/generic/THCTensorMathBlas.h b/aten/src/THC/generic/THCTensorMathBlas.h deleted file mode 100644 index e15baafaca64..000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.h +++ /dev/null @@ -1,7 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.h" -#else - -THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha); - -#endif diff --git a/test/test_torch.py b/test/test_torch.py index 2e310f34be6c..c592411fe25a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -17658,10 +17658,12 @@ def test_bmm(self, device, dtype): (self.device_type == 'cuda' and dtype in cuda_supported_dtypes) if not is_supported: + return + # NOTE: code below has been temporarily short circuited for unsupported types + # since they are supported for some code paths and don't always throw an error. b1 = torch.randn(num_batches, M, N, device=device).to(dtype) b2 = torch.randn(num_batches, N, O, device=device).to(dtype) self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2)) - return def invert_perm(p): d = {x: i for i, x in enumerate(p)} @@ -17878,11 +17880,13 @@ def test_baddbmm(self, device, dtype): (self.device_type == 'cuda' and dtype in cuda_supported_dtypes) if not is_supported: + return + # NOTE: code below has been temporarily short circuited for unsupported types + # since they are supported for some code paths and don't always throw an error. b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) t = make_tensor((num_batches, M, O), device, dtype, low=-1, high=1) self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.baddbmm(t, b1, b2)) - return def invert_perm(p): d = {x: i for i, x in enumerate(p)} @@ -20424,11 +20428,11 @@ def inner(self, device, dtype): ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)), ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True, - [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), + 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True, + [tf32_on_and_off(0.05), _wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('baddbmm', 'two_scalars', _small_3d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True, - [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), + 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, + _cpu_types, True, [tf32_on_and_off(0.05), _wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('bmm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False), ('addcdiv', '', _small_2d, From cfe3defd88b43ba710dd1093e382c5e8c279bd83 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 12 Nov 2020 08:02:08 -0800 Subject: [PATCH 51/93] [vulkan] Enable prepacked addmm/mm for linear layers (#47815) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47815 Test Plan: Imported from OSS Reviewed By: IvanKobzarev Differential Revision: D24908605 Pulled By: SS-JIA fbshipit-source-id: e658bc2dbf23d5d911b979d3b8f467508f2fdf0c --- .../ATen/native/vulkan/ops/Convolution.cpp | 358 +++++++----------- aten/src/ATen/native/vulkan/ops/Convolution.h | 99 +++++ aten/src/ATen/native/vulkan/ops/Mm.cpp | 285 +++++++++----- aten/src/ATen/native/vulkan/ops/Mm.h | 55 +++ .../vulkan/ops/RegisterOpContextClass.cpp | 80 ++++ torch/csrc/jit/passes/vulkan_rewrite.cpp | 52 ++- 6 files changed, 607 insertions(+), 322 deletions(-) create mode 100644 aten/src/ATen/native/vulkan/ops/Convolution.h create mode 100644 aten/src/ATen/native/vulkan/ops/Mm.h create mode 100644 aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index 5bec92abb53d..7cf3b4fe5137 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -1,8 +1,7 @@ -#include +#include #include #include #include -#include namespace at { namespace native { @@ -10,74 +9,6 @@ namespace vulkan { namespace ops { namespace { -class Context final : public torch::jit::CustomClassHolder { - public: - static Context create( - api::Resource::Pool& pool, - const Tensor& weight, - const c10::optional& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool transposed, - IntArrayRef output_padding, - int64_t groups, - c10::optional output_min = c10::nullopt, - c10::optional output_max = c10::nullopt); - - using State = std::tuple< - Tensor, - c10::optional, - std::vector, - std::vector, - std::vector, - int64_t, - c10::optional, - c10::optional>; - - Tensor run(const Tensor& input) const; - State unpack() const; - - private: - Context( - api::Resource::Pool& pool, - const Tensor& weight, - const c10::optional& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool transposed, - IntArrayRef output_padding, - int64_t groups, - c10::optional output_min = c10::nullopt, - c10::optional output_max = c10::nullopt); - - private: - struct { - vTensor v_weight; - vTensor v_bias; - std::array filter; - std::array stride; - std::array padding; - std::array dilation; - int32_t groups; - float output_min; - float output_max; - } packed_; - - struct { - Tensor weight; - c10::optional bias; - std::vector filter; - std::vector stride; - std::vector padding; - std::vector dilation; - int64_t groups; - c10::optional output_min; - c10::optional output_max; - } unpacked_; -}; - inline bool is_depthwise( const IntArrayRef filter, const int64_t groups) { @@ -263,42 +194,6 @@ std::array pack_params(const std::vector& vector) { }; } -Context::Context( - api::Resource::Pool& pool, - const Tensor& weight, - const c10::optional& bias, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const bool /* transposed */, - const IntArrayRef /* output_padding */, - const int64_t groups, - const c10::optional output_min, - const c10::optional output_max) - : packed_{ - pack_weights(pool, weight, groups), - pack_biases(pool, bias, weight), - pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)), - pack_params(expand_param_if_needed(stride, "stride", 2)), - pack_params(expand_param_if_needed(padding, "padding", 2)), - pack_params(expand_param_if_needed(dilation, "dilation", 2)), - groups, - output_min ? output_min->template to() : -std::numeric_limits::infinity(), - output_max ? output_max->template to() : +std::numeric_limits::infinity(), - }, - unpacked_{ - weight, - bias, - weight.sizes().vec(), - stride.vec(), - padding.vec(), - dilation.vec(), - groups, - output_min, - output_max, - } { -} - bool available( const Tensor& weight, const c10::optional& bias, @@ -349,56 +244,6 @@ bool available( true; } -Context Context::create( - api::Resource::Pool& pool, - const Tensor& weight, - const c10::optional& bias, - const IntArrayRef stride_arg, - const IntArrayRef padding_arg, - const IntArrayRef dilation_arg, - const bool transposed, - const IntArrayRef output_padding_arg, - const int64_t groups, - const c10::optional output_min, - const c10::optional output_max) { - const auto stride = expand_param_if_needed(stride_arg, "stride", 2); - const auto padding = expand_param_if_needed(padding_arg, "padding", 2); - const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2); - const auto output_padding = output_padding_arg; // TODO: Deconvolutions - - TORCH_CHECK( - available( - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_min, - output_max), - "Vulkan::convolution not available! " - "Reason: The provided (weight, bias, stride, padding, dilation, groups, " - "transposed, output_padding, output_min, output_max) parameters are either " - "invalid individually or their combination is not supported by Vulkan impl."); - - // Pass in the originals - return Context{ - pool, - weight, - bias, - stride_arg, - padding_arg, - dilation_arg, - transposed, - output_padding_arg, - groups, - output_min, - output_max, - }; -} - bool usable(const Tensor& input) { // Input return (4 == input.ndimension()) && @@ -632,7 +477,126 @@ void conv2d( } } -Tensor Context::run(const Tensor& input_arg) const { +Tensor convolution( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool transposed, + const IntArrayRef output_padding, + const int64_t groups) { + return Conv2dOpContext::create( + api::context()->resource().pool, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups + ).run(input); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl_UNBOXED("convolution_overrideable", convolution); +} + +#endif /* USE_VULKAN_API */ + +} // namespace + +Conv2dOpContext::Conv2dOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool /* transposed */, + const IntArrayRef /* output_padding */, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) + : packed_{ + pack_weights(pool, weight, groups), + pack_biases(pool, bias, weight), + pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)), + pack_params(expand_param_if_needed(stride, "stride", 2)), + pack_params(expand_param_if_needed(padding, "padding", 2)), + pack_params(expand_param_if_needed(dilation, "dilation", 2)), + groups, + output_min ? output_min->template to() : -std::numeric_limits::infinity(), + output_max ? output_max->template to() : +std::numeric_limits::infinity(), + }, + unpacked_{ + weight, + bias, + weight.sizes().vec(), + stride.vec(), + padding.vec(), + dilation.vec(), + groups, + output_min, + output_max, + } { +} + +Conv2dOpContext Conv2dOpContext::create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, + const bool transposed, + const IntArrayRef output_padding_arg, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + const auto stride = expand_param_if_needed(stride_arg, "stride", 2); + const auto padding = expand_param_if_needed(padding_arg, "padding", 2); + const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2); + const auto output_padding = output_padding_arg; // TODO: Deconvolutions + + TORCH_CHECK( + available( + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_min, + output_max), + "Vulkan::convolution not available! " + "Reason: The provided (weight, bias, stride, padding, dilation, groups, " + "transposed, output_padding, output_min, output_max) parameters are either " + "invalid individually or their combination is not supported by Vulkan impl."); + + // Pass in the originals + return Conv2dOpContext{ + pool, + weight, + bias, + stride_arg, + padding_arg, + dilation_arg, + transposed, + output_padding_arg, + groups, + output_min, + output_max, + }; +} + +Tensor Conv2dOpContext::run(const Tensor& input_arg) const { api::Context* const context = api::context(); const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); @@ -708,8 +672,8 @@ Tensor Context::run(const Tensor& input_arg) const { return convert(v_output); } -Context::State Context::unpack() const { - return Context::State{ +Conv2dOpContext::State Conv2dOpContext::unpack() const { + return Conv2dOpContext::State{ unpacked_.weight, unpacked_.bias, unpacked_.stride, @@ -721,7 +685,7 @@ Context::State Context::unpack() const { }; } -c10::intrusive_ptr conv2_clamp_prepack( +c10::intrusive_ptr conv2d_clamp_prepack( Tensor&& weight, c10::optional&& bias, std::vector&& stride, @@ -730,8 +694,8 @@ c10::intrusive_ptr conv2_clamp_prepack( const int64_t groups, const c10::optional output_min, const c10::optional output_max) { - return c10::make_intrusive( - Context::create( + return c10::make_intrusive( + Conv2dOpContext::create( persistent()->pool, std::move(weight), std::move(bias), @@ -747,78 +711,10 @@ c10::intrusive_ptr conv2_clamp_prepack( Tensor conv2d_clamp_run( const Tensor& input, - const c10::intrusive_ptr& context) { + const c10::intrusive_ptr& context) { return context->run(input); } -Tensor convolution( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const bool transposed, - const IntArrayRef output_padding, - const int64_t groups) { - return Context::create( - api::context()->resource().pool, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups - ).run(input); -} - -TORCH_LIBRARY(vulkan, m) { - m.class_("Conv2dOpContext") - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& context) { - return context->unpack(); - }, - // __setstate__ - [](Context::State state) { - return conv2_clamp_prepack( - std::move(std::get<0>(state)), - std::move(std::get<1>(state)), - std::move(std::get<2>(state)), - std::move(std::get<3>(state)), - std::move(std::get<4>(state)), - std::move(std::get<5>(state)), - std::move(std::get<6>(state)), - std::move(std::get<7>(state))); - }); -} - -TORCH_LIBRARY(vulkan_prepack, m) { - m.def( - "conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, " - "int[2] padding, int[2] dilation, int groups, " - "Scalar? output_min=None, Scalar? output_max=None) " - "-> __torch__.torch.classes.vulkan.Conv2dOpContext"); - m.def( - "conv2d_clamp_run(Tensor X, " - "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y"); -} - -TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { - m.impl("conv2d_clamp_prepack", TORCH_FN(conv2_clamp_prepack)); -} - -TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { - m.impl("conv2d_clamp_run", conv2d_clamp_run); -} - -TORCH_LIBRARY_IMPL(aten, Vulkan, m) { - m.impl_UNBOXED("convolution_overrideable", convolution); -} - -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.h b/aten/src/ATen/native/vulkan/ops/Convolution.h new file mode 100644 index 000000000000..2bab7091d4ab --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Convolution.h @@ -0,0 +1,99 @@ +#pragma once +#ifdef USE_VULKAN + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +class Conv2dOpContext final : public torch::jit::CustomClassHolder { + public: + static Conv2dOpContext create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + c10::optional output_min = c10::nullopt, + c10::optional output_max = c10::nullopt); + + using State = std::tuple< + Tensor, + c10::optional, + std::vector, + std::vector, + std::vector, + int64_t, + c10::optional, + c10::optional>; + + Tensor run(const Tensor& input) const; + State unpack() const; + + private: + Conv2dOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + c10::optional output_min = c10::nullopt, + c10::optional output_max = c10::nullopt); + + private: + struct { + vTensor v_weight; + vTensor v_bias; + std::array filter; + std::array stride; + std::array padding; + std::array dilation; + int32_t groups; + float output_min; + float output_max; + } packed_; + + struct { + Tensor weight; + c10::optional bias; + std::vector filter; + std::vector stride; + std::vector padding; + std::vector dilation; + int64_t groups; + c10::optional output_min; + c10::optional output_max; + } unpacked_; +}; + +Tensor conv2d_clamp_run( + const Tensor& input, + const c10::intrusive_ptr& context); + +c10::intrusive_ptr conv2d_clamp_prepack( + Tensor&& weight, + c10::optional&& bias, + std::vector&& stride, + std::vector&& padding, + std::vector&& dilation, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN */ diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index 185f66226e15..ca342e70a7b8 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -1,5 +1,5 @@ -#include -#include +#include +#include namespace at { namespace native { @@ -7,100 +7,113 @@ namespace vulkan { namespace ops { namespace { -Tensor addmm( - const Tensor& self_arg, - const Tensor& mat1_arg, - const Tensor& mat2_arg, - const Scalar beta, - const Scalar alpha) { - api::Context* const context = api::context(); +vTensor pack_weights(api::Resource::Pool& pool, const Tensor& weight_arg) { + return convert(weight_arg.vulkan()); +} - const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); - const vTensor& v_self = convert(self); +vTensor pack_biases( + api::Resource::Pool& pool, + const c10::optional& bias_arg, + const Tensor& weight_arg) { + if (bias_arg) { + return convert(bias_arg->vulkan()); + } else { + vTensor v_bias{ + api::context(), + &pool, + {weight_arg.size(Layout::Parameter::width)}, + weight_arg.options(), + }; - const Tensor mat1 = mat1_arg.is_vulkan() ? mat1_arg : mat1_arg.vulkan(); - const vTensor& v_mat1 = convert(mat1); + using Future = vTensor::Future; + Future v_bias_future = v_bias.host(); + Future::Payload v_bias_payload = v_bias_future.wait(); - const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan(); - const vTensor& v_mat2 = convert(mat2); - - const auto self_sizes = self.sizes(); - const auto mat1_sizes = mat1.sizes(); - const auto mat2_sizes = mat2.sizes(); + memset( + v_bias_payload.get(), + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); - if (self_sizes.size() >= 2) { - TORCH_CHECK( - (mat1_sizes[Layout::Parameter::width] == - mat2_sizes[Layout::Parameter::height]) && - (self_sizes[Layout::Parameter::height] == - mat1_sizes[Layout::Parameter::height]) && - (self_sizes[Layout::Parameter::width] == - mat2_sizes[Layout::Parameter::width]), - "Incompatible matrix dimensions!"); + return v_bias; } - else { - TORCH_CHECK( - (mat1_sizes[Layout::Parameter::width] == - mat2_sizes[Layout::Parameter::height]) && - ((self_sizes[Layout::Parameter::height] == - mat1_sizes[Layout::Parameter::height]) || - (self_sizes[Layout::Parameter::height] == - mat2_sizes[Layout::Parameter::width])), - "Incompatible matrix dimensions!"); +} + +bool available(const Tensor& weight, const c10::optional& bias) { + bool valid = true; + if (bias && bias->ndimension() > 1) { + valid = + (bias->sizes()[Layout::Parameter::width] == + weight.sizes()[Layout::Parameter::width]); } + return api::available() && valid; +} - vTensor v_output{ - context, - { - mat1_sizes[Layout::Parameter::height], - mat2_sizes[Layout::Parameter::width], - }, - self.options(), - }; +bool usable( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias) { + return (input.sizes()[Layout::Parameter::width] == + weight.sizes()[Layout::Parameter::height]); +} - api::Command::Buffer command_buffer = context->command().pool.allocate(); - command_buffer.begin(); - { - if (v_self.has_image()) { - const struct { - float beta, alpha; - } block { - alpha.to(), - beta.to(), - }; +void addmm_impl( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_self, + const vTensor& v_mat1, + const vTensor& v_mat2, + const float beta, + const float alpha) { + if (v_output.has_image() && v_self.has_image() && v_mat1.has_image() && + v_mat2.has_image()) { + const struct { + float alpha, beta; + } block{ + alpha, + beta, + }; - context->dispatch( - command_buffer, - { + context->dispatch( + command_buffer, + { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(addmm), - v_output.extents(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image(command_buffer, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_mat1.image(command_buffer), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_mat2.image(command_buffer), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image(command_buffer), - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } + }, + VK_KERNEL(addmm), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_mat1.image(command_buffer), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_mat2.image(command_buffer), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + context->resource().pool.uniform(block).object); + } else { + TORCH_CHECK(false, "Not implemented!"); } - command_buffer.end(); - command_buffer.submit(context->gpu().queue); +} - return convert(v_output); +Tensor addmm( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const Scalar beta, + const Scalar alpha) { + return LinearOpContext::create(api::context()->resource().pool, mat2, self) + .run(mat1, beta.to(), alpha.to()); } Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) { @@ -121,12 +134,12 @@ Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) { "Incompatible matrix dimensions!"); vTensor v_output{ - context, - { - mat1_sizes[Layout::Parameter::height], - mat2_sizes[Layout::Parameter::width], - }, - mat1.options(), + context, + { + mat1_sizes[Layout::Parameter::height], + mat2_sizes[Layout::Parameter::width], + }, + mat1.options(), }; api::Command::Buffer command_buffer = context->command().pool.allocate(); @@ -136,9 +149,9 @@ Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) { context->dispatch( command_buffer, { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, }, VK_KERNEL(mm), v_output.extents(), @@ -171,6 +184,100 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { #endif /* USE_VULKAN_API */ } // namespace + +LinearOpContext::LinearOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias) + : packed_{ + pack_weights(pool, weight), + pack_biases(pool, bias, weight), + }, + unpacked_{ + weight, + bias, + } { +} + +LinearOpContext LinearOpContext::create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias) { + TORCH_CHECK(available(weight, bias)) + // Pass in the originals + return LinearOpContext{ + pool, + weight, + bias, + }; +} + +Tensor LinearOpContext::run(const Tensor& input_arg, float beta, float alpha) + const { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + TORCH_CHECK( + usable(input, unpacked_.weight, unpacked_.bias), + "Vulkan Linear not usable! " + "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl."); + + vTensor v_output{ + context, + { + input_arg.sizes()[Layout::Parameter::height], + packed_.v_weight.sizes()[Layout::Parameter::width], + }, + input.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (input_arg.ndimension() == 2) { + addmm_impl( + context, + command_buffer, + v_output, + packed_.v_bias, + v_input, + packed_.v_weight, + beta, + alpha); + } else { + TORCH_CHECK( + false, "linear_run does not yet support inputs with ndim > 2!") + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +LinearOpContext::State LinearOpContext::unpack() const { + return LinearOpContext::State{ + unpacked_.weight, + unpacked_.bias, + }; +} + + +c10::intrusive_ptr linear_prepack( + Tensor&& weight, + c10::optional&& bias) { + return c10::make_intrusive(LinearOpContext::create( + persistent()->pool, std::move(weight), std::move(bias))); +} + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& context) { + return context->run(input, 1.0, 1.0); +} + } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Mm.h b/aten/src/ATen/native/vulkan/ops/Mm.h new file mode 100644 index 000000000000..08c84967d00f --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mm.h @@ -0,0 +1,55 @@ +#pragma once +#ifdef USE_VULKAN + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +class LinearOpContext final : public torch::jit::CustomClassHolder { + public: + static LinearOpContext create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias); + + using State = std::tuple>; + + Tensor run(const Tensor& input, float beta, float alpha) const; + State unpack() const; + + private: + LinearOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias); + + private: + struct { + vTensor v_weight; + vTensor v_bias; + } packed_; + + struct { + Tensor weight; + c10::optional bias; + } unpacked_; +}; + +c10::intrusive_ptr linear_prepack( + Tensor&& weight, + c10::optional&& bias); + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& context); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN */ diff --git a/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp b/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp new file mode 100644 index 000000000000..699944b7c48e --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp @@ -0,0 +1,80 @@ +#ifdef USE_VULKAN + +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +TORCH_LIBRARY(vulkan, m) { + m.class_("Conv2dOpContext") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& context) { + return context->unpack(); + }, + // __setstate__ + [](Conv2dOpContext::State state) { + return conv2d_clamp_prepack( + std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state)), + std::move(std::get<6>(state)), + std::move(std::get<7>(state))); + }); + m.class_("LinearOpContext") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& context) { + return context->unpack(); + }, + // __setstate__ + [](LinearOpContext::State state) { + return linear_prepack( + std::move(std::get<0>(state)), std::move(std::get<1>(state))); + }); +} + +TORCH_LIBRARY(vulkan_prepack, m) { + m.def( + "conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, " + "int[2] padding, int[2] dilation, int groups, " + "Scalar? output_min=None, Scalar? output_max=None) " + "-> __torch__.torch.classes.vulkan.Conv2dOpContext"); + m.def( + "conv2d_clamp_run(Tensor X, " + "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y"); + m.def( + "linear_prepack(Tensor W, Tensor? B) " + "-> __torch__.torch.classes.vulkan.LinearOpContext"); + m.def( + "linear_run(Tensor X, " + "__torch__.torch.classes.vulkan.LinearOpContext BW_prepack) -> Tensor Y"); +} + +TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { + m.impl("conv2d_clamp_prepack", TORCH_FN(conv2d_clamp_prepack)); + m.impl("linear_prepack", TORCH_FN(linear_prepack)); +} + +TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { + m.impl("conv2d_clamp_run", TORCH_FN(conv2d_clamp_run)); + m.impl("linear_run", TORCH_FN(linear_run)); +} + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN */ diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 0b4e90f3e1aa..4e381c47dae0 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -22,6 +22,51 @@ namespace jit { namespace { +void insertPrePackedLinearOp(std::shared_ptr& graph) { + // fuse decomposed linear into aten::linear + FuseLinear(graph); + + std::string linear_before_inline = R"( + graph(%linear, %input, %weight, %bias): + %r = prim::CallFunction(%linear, %input, %weight, %bias) + return (%r))"; + std::string prepacked_ops_pattern_before_inline = R"( + graph(%linear, %input, %weight, %bias): + %weight_t = aten::t(%weight) + %packed_weight_bias = vulkan_prepack::linear_prepack( + %weight_t, %bias) + %res = vulkan_prepack::linear_run(%input, %packed_weight_bias) + return (%res))"; + std::string linear_pattern = R"( + graph(%input, %weight, %bias): + %r = aten::linear(%input, %weight, %bias) + return (%r))"; + std::string prepacked_ops_pattern = R"( + graph(%input, %weight, %bias): + %weight_t = aten::t(%weight) + %packed_weight_bias = vulkan_prepack::linear_prepack( + %weight_t, %bias) + %res = vulkan_prepack::linear_run(%input, %packed_weight_bias) + return (%res))"; + + const auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + const auto linear_value = match_vmap.at(vmap.at("linear")); + const auto func_name = graph_rewrite_helper::getFuncName(linear_value); + return (func_name == "linear"); + }; + + SubgraphRewriter linear_call_fn_rewriter; + linear_call_fn_rewriter.RegisterRewritePattern( + linear_before_inline, prepacked_ops_pattern_before_inline); + linear_call_fn_rewriter.runOnGraph(graph, filter); + + SubgraphRewriter linear_rewriter; + linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern); + linear_rewriter.runOnGraph(graph); +} + void insertPrePackedConv2dOp(std::shared_ptr& graph) { graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); @@ -131,6 +176,7 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { } // namespace void vulkanInsertPrePackedOps(std::shared_ptr& graph) { + insertPrePackedLinearOp(graph); insertPrePackedConv2dOp(graph); } @@ -153,8 +199,10 @@ void vulkanFusePrePackedConvWithClamp(script::Module& module) { void vulkanFoldPrePackingOps(script::Module& m) { PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool { return ( - n->kind() == - Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack")); + (n->kind() == + Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack")) || + (n->kind() == + Symbol::fromQualString("vulkan_prepack::linear_prepack"))); }; PrePackingOpsFolder(m, filter_fn, "prepack_folding"); } From b6cb2caa68a3b74a5f4d74f22246e041bcc9c1ca Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Thu, 12 Nov 2020 08:06:00 -0800 Subject: [PATCH 52/93] Revert "Fixed einsum compatibility/performance issues (#46398)" (#47821) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47821 This reverts commit a5c65b86ce249f5f2d365169e6315593fbd47b61. Conflicts: test/test_linalg.py Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D24909923 Pulled By: gchanan fbshipit-source-id: 9dcf98e7c4a3c7e5aaffe475867fa086f3bb6ff2 --- aten/src/ATen/native/Linear.cpp | 498 +++++++++++++------------------- test/test_linalg.py | 139 --------- test/test_torch.py | 77 ++++- torch/functional.py | 166 +++++------ 4 files changed, 350 insertions(+), 530 deletions(-) diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 6f66c7a120fe..c9e03aaa3b6b 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -136,331 +136,241 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra return result; } -// There are roughly three parts to compute einsum: -// 1. Parse equation to extract the labels for each input operand and output -// 2. Unsqueeze missing dimensions from input operands and permute to align them -// 3. Compute result by multiplying input operands and summing contraction -// dimensions We do the last part by reducing to bmm. -Tensor einsum(std::string equation, TensorList operands) { - TORCH_CHECK(!operands.empty(), "einsum() must provide at least one operand"); - checkDeviceType("einsum()", operands, operands[0].device().type()); - - // Code for encoding ellipsis ("...") with labels - constexpr int ELLIPSIS = '.'; - - // Find arrow (->) to split equation into lhs and rhs - const auto arrow_pos = equation.find("->"); - - // Convert labels for input operands into an index in [0, 25] and store - // them in op_labels for each operand along with ELLIPSIS. - std::string lhs = equation.substr(0, arrow_pos); - std::vector> op_labels(operands.size()); - bool found_ell = false; - std::string::size_type curr_op = 0; - for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) { - switch (lhs[i]) { - case ' ': - // Ignore spaces - break; - - case '.': - TORCH_CHECK( - // Only one ellipsis per operand can be given - !found_ell, - "einsum() found \'.\' for operand ", - curr_op, - " for which an ellipsis was already found"); - TORCH_CHECK( - // Ensure it's a valid ellipsis - i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.', - "einsum() found \'.\' for operand ", - curr_op, - " that is not part of any ellipsis"); - op_labels[curr_op].push_back(ELLIPSIS); - found_ell = true; - break; - - case ',': - // Move onto next operand - ++curr_op; - TORCH_CHECK( - curr_op < operands.size(), - "einsum() fewer operands were provided than specified in the equation"); - found_ell = false; - break; - - default: - // Parse label - TORCH_CHECK( - lhs[i] >= 'a' && lhs[i] <= 'z', - "einsum() operand subscript must be in range [a, z] but found ", - lhs[i], - " for operand ", - curr_op); - // Convert label to index in [0, 25] and store - op_labels[curr_op].push_back(lhs[i] - 'a'); - } +Tensor einsum(std::string eqn, TensorList tensors) { + constexpr size_t number_of_letters = 26; + std::string in_eqn; + size_t pos; + // The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis. + // Internally, we represent it using indices from 0 to num_total_dimensions, with each letter + // mapped to an index and the ellipsis ('...') being mapped to a number of consequtive indices. + // The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that + // the letter has not been assigned an index yet (because it has not been seen). + // The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices). + // A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet. + // Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below. + + std::array letter_mapping; // map letter to internal (numerical) label + letter_mapping.fill(-1); + int64_t num_ell_idxes = -1; + int64_t first_ell_idx = 0; + + // The internal representation of the left hand side fo the equation (with ellipsis expanded) is stored in input_op_idxes. + // For each operand, we have a vector mapping each dimension to an internal index. + // We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and + // of the last occurrence of each index. + std::vector> input_op_idxes; // the parsed operand indices + std::array num_letter_occurrences; // number of occurrence in the equation of this letter + num_letter_occurrences.fill(0); + std::vector last_idx_occurrence; // the last operator (left to right) using this index + + if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eq is the left hand side + in_eqn = eqn.substr(0, pos); + } else { + in_eqn = eqn; } - - TORCH_CHECK( - curr_op == operands.size() - 1, - "einsum() more operands were provided than specified in the equation"); - - // Labels must be within [a, z]. - constexpr int total_labels = 'z' - 'a' + 1; - std::vector label_count(total_labels, 0); - - // The maximum number of dimensions covered by any ellipsis, needed when - // unsqueezing missing dimensions from operands to permute and broadcast - int64_t ell_num_dim = 0; - - // Compute label frequency and number of dimensions covered by ellipsis - // We do this after parsing labels to make it more readable and simpler - // to compute the number of dimensions covered by ellipsis. - for (std::size_t i = 0; i < operands.size(); ++i) { - Tensor operand = operands[i]; - std::vector labels = op_labels[i]; - int64_t nlabels = labels.size(); - int64_t ndims = operand.dim(); - bool has_ellipsis = false; - - for (int label : labels) { - if (label == ELLIPSIS) { - --nlabels; - has_ellipsis = true; - ell_num_dim = std::max(ell_num_dim, ndims - nlabels); - } else { - ++label_count[label]; + // remove spaces for einsum compatibility (#9929) + in_eqn.erase(std::remove_if(in_eqn.begin(), in_eqn.end(), isspace), in_eqn.end()); + + // next we parse in_eq (the left hand side) by iterating. It is a string of comma separated terms per index + int64_t operand = 0; + std::stringstream eqn_stream(in_eqn); + std::string term; + int64_t num_total_idxes = 0; + while (! eqn_stream.eof()) { + std::getline(eqn_stream, term, ','); // term = string with indices of current term + TORCH_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we use the dimension + + int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.' + // if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions + int64_t candidate_num_ell_idxes = tensors[operand].dim() - term.size() + 3; + int64_t dims_in_term = 0; // dimensions we have seen + std::vector current_op_idxes; // mapping of operand dimensions to indices for current term + for (auto &c : term) { // c = character with a single letter or '.' + if (c == '.') { + ell_char_count++; + TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in term ", operand, " of the equation"); + if (ell_char_count == 3) { // this completes the ellipsis + if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size + first_ell_idx = num_total_idxes; + num_ell_idxes = candidate_num_ell_idxes; + num_total_idxes += num_ell_idxes; + } + else { // we have seen an ellipsis before, so we check compatibility + TORCH_CHECK(candidate_num_ell_idxes == num_ell_idxes, + "ellipsis must represent ", num_ell_idxes, " dimensions in all terms"); + } + for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices + current_op_idxes.push_back(first_ell_idx + i); + last_idx_occurrence.push_back(operand); + } + dims_in_term += num_ell_idxes; // keep track of dimensions + } + } else { // a letter (hopefully) + TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand); + TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); + int64_t letter_num = c-'a'; // letter_num = position in letter_mapping + if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping + letter_mapping[letter_num] = num_total_idxes; + num_total_idxes++; + last_idx_occurrence.push_back(operand); + } else { // letter we have already seen + last_idx_occurrence[letter_mapping[letter_num]] = operand; + } + num_letter_occurrences[letter_num]++; + current_op_idxes.push_back(letter_mapping[letter_num]); + dims_in_term++; } } - - TORCH_CHECK( - has_ellipsis ? nlabels <= ndims : nlabels == ndims, - "einsum() the number of subscripts in the equation (", - nlabels, - has_ellipsis ? ") is more than the number of dimensions (" - : ") does not match the number of dimensions (", - ndims, - ") for operand ", - i, - has_ellipsis ? "" : " and no ellipsis was given"); + TORCH_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, " tensor ", tensors[operand].dim()); + input_op_idxes.push_back(std::move(current_op_idxes)); + operand++; } - - // Mapping of label to index in the permuted tensors (out_dims + sum_dims) - // This will be used for aligning the dimensions of all input operands - std::vector label_perm_index(total_labels, -1); - - // Current index in the permuted shape - int perm_index = 0; - - // Start index of ellipsis dimensions in the permuted shape - int64_t ell_index = 0; - - if (arrow_pos == std::string::npos) { - // Implicit output is ellipsis (...) + labels seen only once - perm_index = ell_num_dim; - for (int label = 0; label < total_labels; ++label) { - if (label_count[label] == 1) { - label_perm_index[label] = perm_index++; + // in the check below, we need ==, but > is captured above, so the error message can be specific that it is <. + TORCH_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation"); + + // the following parses or infers output (right hand side) + // it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) + // for the output indices. -1 means that the index has not been assigned a dimension yet + std::vector idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions + int64_t num_output_dims = 0; + if (pos != std::string::npos) { // parse the user provided right hand side + int64_t ell_char_count = 0; + for (auto &c : eqn.substr(pos+2)) { + if (c == '.') { // '.' as part of ellipsis + ell_char_count++; + TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in right hand side of the equation"); + if (ell_char_count == 3) { // ellipsis complete + TORCH_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side"); + for (int64_t i = 0; i < num_ell_idxes; ++i) { + idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; + num_output_dims++; + } + } + } else if (! isspace(c)) { // letter (hopefully) + TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side"); + TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); + int64_t letter_num = c-'a'; + TORCH_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output"); + idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims; + num_output_dims++; } } - } else { - // Parse explicit output - std::string rhs = equation.substr(arrow_pos + 2); - found_ell = false; - for (std::size_t i = 0; i < rhs.length(); ++i) { - switch (rhs[i]) { - case ' ': - // Ignore spaces - break; - - case '.': - TORCH_CHECK( - // There can only be one ellipsis in the output - !found_ell, - "einsum() found \'.\' for output but an ellipsis (...) was already found"); - TORCH_CHECK( - // Ensure ellipsis is correct - i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.', - "einsum() found \'.\' for output that is not part of any ellipsis (...)"); - ell_index = perm_index; - perm_index += ell_num_dim; - found_ell = true; - break; - - default: - TORCH_CHECK( - rhs[i] >= 'a' && rhs[i] <= 'z', - "einsum() subscripts must be in range [a, z] but found ", - rhs[i], - " for the output"); - TORCH_CHECK( - // Ensure label appeared at least once for some input operand and at - // most once for the output - label_count[rhs[i] - 'a'] > 0, - "einsum() output subscript ", - rhs[i], - label_count[rhs[i] - 'a'] == -1 - ? " appears more than once in the output" - : " does not appear in the equation for any input operand"); - label_perm_index[rhs[i] - 'a'] = perm_index++; - - // Set to -1 to mark that this label already appeared in the output - label_count[rhs[i] - 'a'] = -1; + } else { // create an inferred right hand side + // the ellipsis (if in the lhs) comes first + if (num_ell_idxes >= 0) { + for (int64_t i = 0; i < num_ell_idxes; ++i) { + idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; + num_output_dims++; + } + } + // then the indices that occur exactly once in alphabetic order + for (size_t idx = 0; idx < number_of_letters; idx++) { + if (num_letter_occurrences[idx] == 1) { + idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims; + num_output_dims++; } } - - TORCH_CHECK( - // Dimensions under ellipsis are not contracted, so ensure it appears in output - ell_num_dim <= 0 || found_ell, - "einsum() ellipsis (...) covering one or more dimensions was given in the input but not in the output"); } - - // Save output size before adding sum dims - int out_size = perm_index; - - // Add contraction labels (labels not present in output) - for (int label = 0; label < total_labels; ++label) { - if (label_count[label] > 0 && label_perm_index[label] == -1) { - label_perm_index[label] = perm_index++; + // now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) + // for the non-output indices - those that are eventually summed over + int64_t position = num_output_dims; + for (int64_t i = 0; i < num_total_idxes; i++) { + if (idxes_to_preprocessed_dims[i]==-1) { + idxes_to_preprocessed_dims[i] = position; + position++; } } - // Here we unsqueeze missing dimensions to make all operands have the same - // number of dimensions. We take diagonals for repeated labels within the - // same operand. Finally we permute the operands to align dimensions as - // per the perm_out_index we computed above. - std::vector permuted_operands; - for (std::size_t i = 0; i < operands.size(); ++i) { - std::vector perm_shape(perm_index, -1); - std::vector label_dim(total_labels, -1); - std::vector labels = op_labels[i]; - Tensor operand = operands[i]; - std::size_t j = 0; - - for (int label : labels) { - if (label == ELLIPSIS) { - // Add missing dimensions under ellipsis - int64_t num_dim_diff = - ell_num_dim - (operand.dim() - labels.size() + 1); - for (int64_t k = 0; k < num_dim_diff; ++k) { - operand = operand.unsqueeze(j); + // we now "homogenize the dimensions", i.e. + // - take diagonals for duplicated indices + // - permute the dimensions to match the order given by idxes_to_preprocessed_dims + // - unsqueeze to create all dimensions for each index in each tensor where they are missing + // we also check that sizes match + // after this, all operands will have compatible shapes (i.e. all dimensions are aligned are broadcastable) + std::vector preprocessed_operands; + std::vector size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet + for (int64_t op = 0; op < (int64_t) tensors.size(); op++) { + auto preprocessed_op = tensors[op]; + std::vector idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear + std::vector& current_op_input_idxes = input_op_idxes[op]; + int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input + for (size_t i = 0; i < current_op_input_idxes.size(); i++) { + auto idx = current_op_input_idxes[i]; + auto dim_out = idxes_to_preprocessed_dims[idx]; + if (idx_to_dim[dim_out] == -1) { // first appearance + idx_to_dim[dim_out] = dim; + if (size_of_dims[idx] == -1) { // keep track of sizes + size_of_dims[idx] = preprocessed_op.size(dim); } - for (int64_t k = 0; k < ell_num_dim; ++k) { - perm_shape[ell_index + k] = j++; + else { + TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); } - } else if (label_dim[label] != -1) { - // Repeated label, take diagonal - int64_t dim = label_dim[label]; - TORCH_CHECK( - operand.size(j) == operand.size(dim), - "einsum() subscript ", - char(label + 'a'), - " is repeated for operand ", - i, - " but the sizes don't match, ", - operand.size(j), - " != ", - operand.size(dim)); - operand = operand.diagonal(0, j, dim).movedim(-1, dim); - } else { - // Lookup output index for label - label_dim[label] = j; - perm_shape[label_perm_index[label]] = j++; + dim++; + } else { // duplicate dimension in tensor --> take diagonal of idx_to_dim[dim_out] and dim and put the diagonal dimension to idx_to_dim[dim_out] + TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); + preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim); + // diagonal moves the diagonal dimension to the back + // now we permute the last dim back to idx_to_dim[dim_out] + std::vector perm(preprocessed_op.dim(), 0); + for (int64_t d = 0; d < preprocessed_op.dim(); d++) { + if (d == idx_to_dim[dim_out]) { + perm[d] = preprocessed_op.dim() - 1; + } else { + perm[d] = d - (d > idx_to_dim[dim_out]); + } + } + preprocessed_op = preprocessed_op.permute(perm); } } - - // Add dimensions for missing labels - for (int64_t& index : perm_shape) { - if (index == -1) { - operand = operand.unsqueeze(-1); - index = j++; + // now we permute the dimensions in the right order + std::vector permutation; // permutation for this tensor + for (auto &d : idx_to_dim) { + if (d > -1) { + permutation.push_back(d); } } - - permuted_operands.push_back(operand.permute(perm_shape)); - } - - // Check if operands broadcast and keep track of last operand with - // dimension size != 1 for optimizing reductions - std::vector dim_last_op(perm_index, 0); - bool has_zero_size_dim = false; - for (int dim = 0; dim < perm_index; ++dim) { - int64_t broadcast_size = permuted_operands[0].size(dim); - for (std::size_t i = 1; i < permuted_operands.size(); ++i) { - int64_t dim_size = permuted_operands[i].size(dim); - if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) { - std::ostringstream msg; - msg << "einsum() operands do not broadcast with remapped shapes [original->remapped]:"; - for (std::size_t j = 0; j < operands.size(); ++j) { - msg << " " << operands[j].sizes() << "->" - << permuted_operands[j].sizes(); - } - TORCH_CHECK(false, msg.str()); - } - if (dim_size != 1) { - broadcast_size = dim_size; - dim_last_op[dim] = i; + preprocessed_op = preprocessed_op.permute(permutation); + // finally, we insert dimensions for idxes not in the operand + for (size_t dim = 0; dim < idx_to_dim.size(); dim++) { + if (idx_to_dim[dim] == -1) { + preprocessed_op = preprocessed_op.unsqueeze(dim); } } - has_zero_size_dim |= broadcast_size == 0; - } - // Compute result - Tensor result = permuted_operands[0]; - - // Fast path for when an operand has zero sized dim - if (has_zero_size_dim) { - std::vector out_shape(out_size); - for (int i = 0; i < out_size; ++i) { - out_shape[i] = permuted_operands[dim_last_op[i]].size(i); - } - return at::zeros(out_shape, result.options()); + preprocessed_operands.push_back(std::move(preprocessed_op)); } - // Sum out or squeeze dimensions that are size 1 for all later operands - int dim = out_size; - for (int i = dim; i < perm_index; ++i, ++dim) { - if (dim_last_op[i] == 0) { - if (result.size(dim) == 1) { - result = result.squeeze(dim--); - } else { - result = result.sum(dim--); - } + // now we reduce the indices from left to right + // numpy allows to optimize the path using various + // algorithms (see eigen_path in numpy docs) + // we start with the leftmost operator and reduce indices that + // appear only there + Tensor result = std::move(preprocessed_operands[0]); + for (int64_t idx = 0; idx < num_total_idxes; idx++) { + if ((last_idx_occurrence[idx] == 0) + && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { + result = result.sum(idxes_to_preprocessed_dims[idx], true); } } - for (std::size_t i = 1; i < permuted_operands.size(); ++i) { - Tensor operand = permuted_operands[i]; + // now we process each tensor using sumproduct_pair + for (int64_t i = 1; i < (int64_t) preprocessed_operands.size(); i++) { std::vector sum_dims; - - // Sum out or squeeze dimensions that are size 1 for all later operands - dim = out_size; - for (int j = dim; j < perm_index; ++j, ++dim) { - if (dim_last_op[j] < i) { - operand = operand.squeeze(dim); - --dim; - } else if (dim_last_op[j] == i) { - if (result.size(dim) == 1) { - operand = operand.sum(dim); - result = result.squeeze(dim); - --dim; - } else { - sum_dims.push_back(dim); - } + for (int64_t idx = 0; idx < num_total_idxes; idx++) { + if ((last_idx_occurrence[idx] == i) + && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { + sum_dims.push_back(idxes_to_preprocessed_dims[idx]); } } - - // Multiply tensors and sum out dimensions in sum_dims - if (sum_dims.empty()) { - result = result.mul(operand); - } else if (sum_dims.size() == result.sizes().size()) { - result = result.flatten().dot(operand.flatten()); - } else { - result = sumproduct_pair(result, operand, sum_dims, false); - } + result = at::native::sumproduct_pair(result, std::move(preprocessed_operands[i]), sum_dims, true); + } + // finally, we squeeze out all non-result dimensions + auto sizes = result.sizes().vec(); + for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--) { + sizes.erase(sizes.begin() + dim); } + result = result.view(sizes); return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index a6a7fa9a8088..67bf450a006c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1324,145 +1324,6 @@ def test_dot_invalid_args(self, device): self._test_dot_vdot_invalid_args(device, torch.dot) self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True) - def test_einsum(self, device): - def check(equation, *operands): - ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands]) - res = torch.einsum(equation, operands) - self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) - - # Autograd check (FIXME: tests below fail check) - if equation not in {"i,i->", "i,i->i", "ij,ij->ij"}: - ops = [op.detach().requires_grad_() for op in operands] - self.assertTrue(torch.autograd.gradcheck(lambda *ops: torch.einsum(equation, ops), ops)) - for op in ops: - self.assertTrue(op._version == 0) - - # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f - x = torch.rand(5, device=device) - y = torch.rand(7, device=device) - A = torch.randn(3, 5, device=device) - B = torch.randn(2, 5, device=device) - C = torch.randn(2, 3, 5, device=device) - D = torch.randn(2, 5, 7, device=device) - E = torch.randn(7, 9, device=device) - F = torch.randn(2, 3, 3, 5, device=device) - G = torch.randn(5, 4, 6, device=device) - H = torch.randn(4, 4, device=device) - I = torch.rand(2, 3, 2, device=device) - - # Vector operations - check('i->', x) # sum - check('i,i->', x, x) # dot - check('i,i->i', x, x) # vector element-wisem mul - check('i,j->ij', x, y) # outer - - # Matrix operations - check("ij->ji", A) # transpose - check("ij->j", A) # row sum - check("ij->i", A) # col sum - check("ij,ij->ij", A, A) # matrix element-wise mul - check("ij,j->i", A, x) # matrix vector multiplication - check("ij,kj->ik", A, B) # matmul - check("ij,ab->ijab", A, E) # matrix outer product - - # Tensor operations - check("aij,ajk->aik", C, D) # batch matmul - check("ijk,jk->i", C, A) # tensor matrix contraction - check("aij,jk->aik", D, E) # tensor matrix contraction - check("abcd,dfg->abcfg", F, G) # tensor tensor contraction - check("ijk,jk->ik", C, A) # tensor matrix contraction with double indices - check("ijk,jk->ij", C, A) # tensor matrix contraction with double indices - check("ijk,ik->j", C, B) # non contiguous - check("ijk,ik->jk", C, B) # non contiguous with double indices - - # Test diagonals - check("ii", H) # trace - check("ii->i", H) # diagonal - check('iji->j', I) # non-contiguous trace - - # Test ellipsis - check("i...->...", H) - check("ki,...k->i...", A.t(), B) - check("k...,jk", A.t(), B) - check('...ik, ...kj -> ...ij', torch.rand(2, 3, 4), torch.rand(1, 5)) - check('bik,k...j->i...j', torch.rand(5, 2, 3), torch.rand(3, 2)) - check('i...j, ij... -> ...ij', torch.rand(2, 3, 4), torch.rand(2, 4, 2, 3)) - - # torch.bilinear with discontiguous tensors - l = torch.randn(10, 5, device=device).transpose(0, 1) - r = torch.randn(20, 5, device=device).transpose(0, 1) - w = torch.randn(15, 10, 20, device=device) - check("bn,anm,bm->ba", l, w, r) - # with strided tensors - check("bn,anm,bm->ba", l[:, ::2], w[:, ::2, ::2], r[:, ::2]) - - def test_einsum_corner_cases(self, device): - def check(equation, *operands, expected_output): - tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple) - else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands] - output = torch.einsum(equation, tensors) - self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) - - # Test equation variantions - check(' ', 1, expected_output=1) - check(' -> ', 1, expected_output=1) - check(' , ', 2, 2, expected_output=4) - check(' , , ', 2, 2, 2, expected_output=8) - check(' , -> ', 2, 2, expected_output=4) - check(' i ', [1], expected_output=[1]) - check(' i -> ', [1], expected_output=1) - check(' i -> i ', [1], expected_output=[1]) - check(' i , i ', [2], [2], expected_output=4) - check(' i , i -> i ', [2], [2], expected_output=[4]) - - # Test tensors with 0 size dimensions - check('i', [], expected_output=[]) - check(' i j -> j', [[], []], expected_output=[]) - check('ij->i', [[], []], expected_output=[0., 0.]) - check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []]) - - # Test broadcasting - check('i,j', [2], [1, 2], expected_output=[[2, 4]]) - check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]]) - - # Test ellipsis broadcasting - check('...', 1, expected_output=1) - check('...->', 1, expected_output=1) - check('...->...', 1, expected_output=1) - check('i...->i', [1], expected_output=[1]) - check('i...->...i', [1], expected_output=[1]) - - def test_einsum_error_cases(self, device): - def check(equation, operands, regex, exception=RuntimeError): - with self.assertRaisesRegex(exception, r'einsum\(\) ' + regex): - torch.einsum(equation, operands) - - x = torch.rand(2) - y = torch.rand(2, 3) - - check('', [], r'must provide at least one operand') - check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis') - check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found') - check('A', [x], r'operand subscript must be in range \[a, z\] but found A for operand 0') - check(',', [x], r'fewer operands were provided than specified in the equation') - check('', [x, x], r'more operands were provided than specified in the equation') - check('', [x], r'the number of subscripts in the equation \(0\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number ' - r'of dimensions \(1\) for operand 0') - check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found') - check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)') - check('a->A', [x], r'subscripts must be in range \[a, z\] but found A for the output') - check('a->aa', [x], r'output subscript a appears more than once in the output') - check('a->i', [x], r'output subscript i does not appear in the equation for any input operand') - check('...->', [x], r'ellipsis \(...\) covering one or more dimensions was given in the input ' - r'but not in the output') - check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') - check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' - r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') - def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, device, dtype): triangle_function = torch.triu if upper else torch.tril diff --git a/test/test_torch.py b/test/test_torch.py index c592411fe25a..1067c56375eb 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -39,7 +39,7 @@ onlyCUDA, onlyCPU, \ dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \ PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic -from typing import Dict, List +from typing import Dict, List, Tuple, Union import torch.backends.quantized import torch.testing._internal.data from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off @@ -16072,6 +16072,81 @@ def test_helper(min, max): test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max) + @onlyCPU + @slowTest + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + @dtypes(torch.double) + def test_einsum(self, device: torch.device, dtype: torch.dtype) -> None: + # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f + x = torch.randn(5, dtype=dtype, device=device) + y = torch.randn(7, dtype=dtype, device=device) + A = torch.randn(3, 5, dtype=dtype, device=device) + B = torch.randn(2, 5, dtype=dtype, device=device) + C = torch.randn(2, 3, 5, dtype=dtype, device=device) + D = torch.randn(2, 5, 7, dtype=dtype, device=device) + E = torch.randn(7, 9, dtype=dtype, device=device) + F = torch.randn(2, 3, 5, 7, dtype=dtype, device=device) + G = torch.randn(7, 11, 13, dtype=dtype, device=device) + H = torch.randn(4, 4, dtype=dtype, device=device) + I = torch.randn(3, 4, 4, dtype=dtype, device=device) + l = torch.randn(5, 10, dtype=dtype, device=device) + r = torch.randn(5, 20, dtype=dtype, device=device) + w = torch.randn(30, 10, 20, dtype=dtype, device=device) + test_list: List[Union[Tuple[str, torch.Tensor], + Tuple[str, torch.Tensor, torch.Tensor], + Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]] = [ + # -- Vector + ("i->", x), # sum + ("i,i->", x, x), # dot + ("i,i->i", x, x), # vector element-wise mul + ("i,j->ij", x, y), # outer + # -- Matrix + ("ij->ji", A), # transpose + ("ij->j", A), # row sum + ("ij->i", A), # col sum + ("ij,ij->ij", A, A), # matrix element-wise mul + ("ij,j->i", A, x), # matrix vector multiplication + ("ij,kj->ik", A, B), # matmul + ("ij,ab->ijab", A, E), # matrix outer product + # -- Tensor + ("aij,ajk->aik", C, D), # batch matmul + ("ijk,jk->i", C, A), # tensor matrix contraction + ("aij,jk->aik", D, E), # tensor matrix contraction + ("abcd,dfg->abcfg", F, G), # tensor tensor contraction + ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices + ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices + ("ijk,ik->j", C, B), # non contiguous + ("ijk,ik->jk", C, B), # non contiguous with double indices + # -- Diagonal + ("ii", H), # trace + ("ii->i", H), # diagonal + # -- Ellipsis + ("i...->...", H), + ("ki,...k->i...", A.t(), B), + ("k...,jk", A.t(), B), + ("...ii->...i", I), # batch diagonal + # -- Other + ("bn,anm,bm->ba", l, w, r), # as torch.bilinear + ("... ii->...i ", I), # batch diagonal with spaces + ] + for test in test_list: + actual = torch.einsum(test[0], test[1:]) + expected = np.einsum(test[0], *[t.numpy() for t in test[1:]]) + self.assertEqual(expected.shape, actual.shape, msg=test[0]) + self.assertEqual(expected, actual, msg=test[0]) + # test vararg + actual2 = torch.einsum(test[0], *test[1:]) + self.assertEqual(expected.shape, actual2.shape, msg=test[0]) + self.assertEqual(expected, actual2, msg=test[0]) + + def do_einsum(*args): + return torch.einsum(test[0], args) + # FIXME: following test cases fail gradcheck + if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: + gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) + self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) + self.assertTrue(A._version == 0) # check that we do not use inplace ops + @onlyCPU @dtypes(torch.bool, torch.double) def test_sum_all(self, device, dtype) -> None: diff --git a/torch/functional.py b/torch/functional.py index e26b4c1b4125..3781b73a178e 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -262,102 +262,76 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): def einsum(equation, *operands): r"""einsum(equation, *operands) -> Tensor - Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation - based on the Einstein summation convention. - - Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them - in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of - this format are described below, but the general idea is to label every dimension of the input :attr:`operands` - with some subscript and define which subscripts are part of the output. The output is then computed by summing - the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the - output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. - Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). - - Equation: - - The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of - the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a - comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript - must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is - repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand - must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that - appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. - The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based - on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. - - Optionally, the output subscripts can be explictly defined by adding an arrow ('->') at the end of the equation - followed by the subscripts for the output. For instance, the following equation computes the transpose of a - matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and - at most once for the output. - - Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. - Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, - e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth - dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the - 'shape' of the ellipsis (the size of the dimensions covered by them) must be broadcastable. In implicit mode, - the ellipsis will come first in the output. In explicit mode, if an ellipses covers at least one dimension then - it must appear in the output since the dimensions under the ellipsis cannot be summed over. e.g. the following - equation implements batch matrix multiplication `'...ij,...jk->...ik'`. - - A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, - arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. - - .. note:: - - This function does not optimize the given expression, so a different formula for the same computation may - run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) - can optimize the formula for you. - - Args: - equation (string): The subscripts for the Einstein summation. - operands (Tensor): The operands to compute the Einstein sum of. - - Examples:: - - # trace - >>> torch.einsum('ii', torch.randn(4, 4)) - tensor(-1.2104) - - # diagonal - >>> torch.einsum('ii->i', torch.randn(4, 4)) - tensor([-0.1034, 0.7952, -0.2433, 0.4545]) - - # outer product - >>> x = torch.randn(5) - >>> y = torch.randn(4) - >>> torch.einsum('i,j->ij', x, y) - tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], - [-0.3744, 0.9381, 1.2685, -1.6070], - [ 0.7208, -1.8058, -2.4419, 3.0936], - [ 0.1713, -0.4291, -0.5802, 0.7350], - [ 0.5704, -1.4290, -1.9323, 2.4480]]) - - # batch matrix multiplication - >>> As = torch.randn(3,2,5) - >>> Bs = torch.randn(3,5,4) - >>> torch.einsum('bij,bjk->bik', As, Bs) - tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], - [-1.6706, -0.8097, -0.8025, -2.1183]], - - [[ 4.2239, 0.3107, -0.5756, -0.2354], - [-1.4558, -0.3460, 1.5087, -0.8530]], - - [[ 2.8153, 1.8787, -4.3839, -1.2112], - [ 0.3728, -2.1131, 0.0921, 0.8305]]]) - - # batch permute - >>> A = torch.randn(2, 3, 4, 5) - >>> torch.einsum('...ij->...ji', A).shape - torch.Size([2, 3, 5, 4]) - - # equivalent to torch.nn.functional.bilinear - >>> A = torch.randn(3,5,4) - >>> l = torch.randn(2,5) - >>> r = torch.randn(2,4) - >>> torch.einsum('bn,anm,bm->ba', l, A, r) - tensor([[-0.3430, -5.2405, 0.4494], - [ 0.3311, 5.5201, -3.0356]]) - """ +This function provides a way of computing multilinear expressions (i.e. sums of products) using the +Einstein summation convention. + +Args: + equation (string): The equation is given in terms of lower case letters (indices) to be associated + with each dimension of the operands and result. The left hand side lists the operands + dimensions, separated by commas. There should be one index letter per tensor dimension. + The right hand side follows after `->` and gives the indices for the output. + If the `->` and right hand side are omitted, it implicitly defined as the alphabetically + sorted list of all indices appearing exactly once in the left hand side. + The indices not apprearing in the output are summed over after multiplying the operands + entries. + If an index appears several times for the same operand, a diagonal is taken. + Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred, + the ellipsis dimensions are at the beginning of the output. + operands (Tensor): The operands to compute the Einstein sum of. + +.. note:: + + This function does not optimize the given expression, so a different formula for the same computation may + run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) + can optimize the formula for you. + +Examples:: + + >>> x = torch.randn(5) + >>> y = torch.randn(4) + >>> torch.einsum('i,j->ij', x, y) # outer product + tensor([[-0.0570, -0.0286, -0.0231, 0.0197], + [ 1.2616, 0.6335, 0.5113, -0.4351], + [ 1.4452, 0.7257, 0.5857, -0.4984], + [-0.4647, -0.2333, -0.1883, 0.1603], + [-1.1130, -0.5588, -0.4510, 0.3838]]) + + + >>> A = torch.randn(3,5,4) + >>> l = torch.randn(2,5) + >>> r = torch.randn(2,4) + >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear + tensor([[-0.3430, -5.2405, 0.4494], + [ 0.3311, 5.5201, -3.0356]]) + + + >>> As = torch.randn(3,2,5) + >>> Bs = torch.randn(3,5,4) + >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication + tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], + [-1.6706, -0.8097, -0.8025, -2.1183]], + + [[ 4.2239, 0.3107, -0.5756, -0.2354], + [-1.4558, -0.3460, 1.5087, -0.8530]], + + [[ 2.8153, 1.8787, -4.3839, -1.2112], + [ 0.3728, -2.1131, 0.0921, 0.8305]]]) + + >>> A = torch.randn(3, 3) + >>> torch.einsum('ii->i', A) # diagonal + tensor([-0.7825, 0.8291, -0.1936]) + + >>> A = torch.randn(4, 3, 3) + >>> torch.einsum('...ii->...i', A) # batch diagonal + tensor([[-1.0864, 0.7292, 0.0569], + [-0.9725, -1.0270, 0.6493], + [ 0.5832, -1.1716, -1.5084], + [ 0.4041, -1.1690, 0.8570]]) + + >>> A = torch.randn(2, 3, 4, 5) + >>> torch.einsum('...ij->...ji', A).shape # batch permute + torch.Size([2, 3, 5, 4]) +""" if not torch.jit.is_scripting(): if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): return handle_torch_function(einsum, operands, equation, *operands) From 00a3add425d305f79a413449d570e8f08e7a4c81 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Thu, 12 Nov 2020 09:27:21 -0800 Subject: [PATCH 53/93] [TorchBind] Support using lambda function as TorchBind constructor (#47819) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47819 Reviewed By: wanchaol Differential Revision: D24910065 Pulled By: gmagogsfm fbshipit-source-id: ad5b4f67b0367e44fe486d31a060d9ad1e0cf568 --- .../jit/test_custom_class_registrations.cpp | 18 ++++++++++ test/jit/test_torchbind.py | 7 ++++ torch/custom_class.h | 33 +++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index f563120bbc6c..fc2d83d76409 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder { } }; +struct LambdaInit : torch::CustomClassHolder { + int x, y; + LambdaInit(int x_, int y_) : x(x_), y(y_) {} + int64_t diff() { + return this->x - this->y; + } +}; + struct NoInit : torch::CustomClassHolder { int64_t x; }; @@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def("add", &Foo::add) .def("combine", &Foo::combine); + m.class_("_LambdaInit") + .def(torch::init([](int64_t x, int64_t y, bool swap) { + if (swap) { + return c10::make_intrusive(y, x); + } else { + return c10::make_intrusive(x, y); + } + })) + .def("diff", &LambdaInit::diff); + m.class_("_NoInit").def( "get_x", [](const c10::intrusive_ptr& self) { return self->x; }); diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index c1ca50270197..866170545747 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -338,3 +338,10 @@ def test_torchbind_attr_exception(self): foo = torch.classes._TorchScriptTesting._StackString(["test"]) with self.assertRaisesRegex(AttributeError, 'does not have a field'): foo.bar + + def test_lambda_as_constructor(self): + obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False) + self.assertEqual(obj_no_swap.diff(), 1) + + obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True) + self.assertEqual(obj_swap.diff(), -1) diff --git a/torch/custom_class.h b/torch/custom_class.h index 571a584294db..080d9d9d3c95 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -27,6 +27,21 @@ detail::types init() { return detail::types{}; } +template +struct InitLambda { + Func f; +}; + +template +decltype(auto) init(Func&& f) { + using InitTraits = + c10::guts::infer_function_traits_t>; + using ParameterTypeList = typename InitTraits::parameter_types; + + InitLambda init{std::forward(f)}; + return init; +} + /// Entry point for custom C++ class registration. To register a C++ class /// in PyTorch, instantiate `torch::class_` with the desired class as the /// template parameter. Typically, this instantiation should be done in @@ -95,6 +110,24 @@ class class_ { return *this; } + // Used in combination with torch::init([]lambda(){......}) + template + class_& def( + InitLambda> init, + std::string doc_string = "") { + auto init_lambda_wrapper = [func = std::move(init.f)]( + c10::tagged_capsule self, + ParameterTypes... arg) { + c10::intrusive_ptr classObj = + at::guts::invoke(func, std::forward(arg)...); + auto object = self.ivalue.toObject(); + object->setSlot(0, c10::IValue::make_capsule(classObj)); + }; + defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string)); + + return *this; + } + /// This is the normal method registration API. `name` is the name that /// the method will be made accessible by in Python and TorchScript. /// `f` is a callable object that defines the method. Typically `f` From 809660ffa49e04f7334a1fba45d4e27c05cfc837 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 12 Nov 2020 09:51:21 -0800 Subject: [PATCH 54/93] ATen DerivedType is dead, long live ATen RegisterDispatchKey (#47011) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47011 smessmer has complained about how it is difficult to find generated code. Well hopefully this diffs helps a bit with that. There are three components to this refactor: - Rename TypeDerived (CPUType) to RegisterDispatchKey (RegisterCPU). The 'Type' nomenclature is vestigial and I think Register says what these files do a lot more clearly. I also got rid of the CPUType namespace; everything just goes in anonymous namespace now, less moving parts this way. - Give Math and DefaultBackend their own files (RegisterMath and RegisterDefaultBackend) - Restructure code generation so that schema definition is done completely separately from RegisterDispatchKey I decided to name the files RegisterCPU rather than the old convention BackendSelectRegister, because it seems better to me if these files clump together in an alphabetical listing rather than being spread out everywhere. There are a few manual registration files which should probably get similar renaming. I also did a little garden cleaning about how we identify if a dispatch key is a cuda key or a generic key (previously called KEYWORD_ALL_BACKENDS but I like my naming better). Signed-off-by: Edward Z. Yang Differential Revision: D24600806 Test Plan: Imported from OSS Reviewed By: smessmer Pulled By: ezyang fbshipit-source-id: c1b510dd7515bd95e3ad25b8edf961b2fb30a25a --- BUILD.bazel | 14 +- ...Register.cpp => RegisterBackendSelect.cpp} | 0 ...ypeDerived.cpp => RegisterDispatchKey.cpp} | 19 +-- .../{TypeDefault.cpp => RegisterSchema.cpp} | 16 +- c10/core/DispatchKey.h | 18 +- tools/codegen/gen.py | 158 ++++++++---------- 6 files changed, 90 insertions(+), 135 deletions(-) rename aten/src/ATen/templates/{BackendSelectRegister.cpp => RegisterBackendSelect.cpp} (100%) rename aten/src/ATen/templates/{TypeDerived.cpp => RegisterDispatchKey.cpp} (75%) rename aten/src/ATen/templates/{TypeDefault.cpp => RegisterSchema.cpp} (92%) diff --git a/BUILD.bazel b/BUILD.bazel index 7dc0e6d213fb..d009e690fd57 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -125,15 +125,17 @@ genrule( ] + glob(["aten/src/ATen/templates/**"]), outs = [ "aten/src/ATen/Declarations.yaml", - "aten/src/ATen/BackendSelectRegister.cpp", - "aten/src/ATen/CPUType.cpp", + "aten/src/ATen/RegisterBackendSelect.cpp", + "aten/src/ATen/RegisterCPU.cpp", + "aten/src/ATen/RegisterMkldnnCPU.cpp", + "aten/src/ATen/RegisterQuantizedCPU.cpp", + "aten/src/ATen/RegisterSparseCPU.cpp", + "aten/src/ATen/RegisterMath.cpp", + "aten/src/ATen/RegisterDefaultBackend.cpp", + "aten/src/ATen/RegisterSchema.cpp", "aten/src/ATen/Functions.h", "aten/src/ATen/Functions.cpp", "aten/src/ATen/NativeFunctions.h", - "aten/src/ATen/MkldnnCPUType.cpp", - "aten/src/ATen/QuantizedCPUType.cpp", - "aten/src/ATen/SparseCPUType.cpp", - "aten/src/ATen/TypeDefault.cpp", "aten/src/ATen/core/TensorBody.h", "aten/src/ATen/core/TensorMethods.cpp", "aten/src/ATen/core/ATenOpList.cpp", diff --git a/aten/src/ATen/templates/BackendSelectRegister.cpp b/aten/src/ATen/templates/RegisterBackendSelect.cpp similarity index 100% rename from aten/src/ATen/templates/BackendSelectRegister.cpp rename to aten/src/ATen/templates/RegisterBackendSelect.cpp diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp similarity index 75% rename from aten/src/ATen/templates/TypeDerived.cpp rename to aten/src/ATen/templates/RegisterDispatchKey.cpp index 3275ab76ef62..0d935e77fe0b 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -33,21 +33,14 @@ namespace at { -/* example -Tensor * ${Type}::add(Tensor & a, Tensor & b) { - std::cout << "add Tensor with backend ${Backend}\n"; - return &a; -} -*/ - -namespace ${Type} { +namespace { -${type_derived_method_definitions} +${dispatch_definitions} -} // namespace ${Type} - -TORCH_LIBRARY_IMPL(aten, ${Backend}, m) { - ${function_registrations} +TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) { + ${dispatch_registrations} } +} // anonymous namespace + } // namespace at diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/RegisterSchema.cpp similarity index 92% rename from aten/src/ATen/templates/TypeDefault.cpp rename to aten/src/ATen/templates/RegisterSchema.cpp index 4cd1d1586d6a..a932bf3f87bc 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/RegisterSchema.cpp @@ -17,14 +17,8 @@ #include namespace at { -namespace TypeDefault { - -${type_method_definitions} - -} // namespace TypeDefault - TORCH_LIBRARY(aten, m) { - ${function_registrations}; + ${schema_registrations}; // String Ops // Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp @@ -63,12 +57,4 @@ TORCH_LIBRARY(aten, m) { // Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); } - -TORCH_LIBRARY_IMPL(aten, Math, m) { - ${math_function_registrations}; -} - -TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { - ${default_backend_function_registrations}; -} } // namespace at diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index aa4f11fe1439..4b6ca26757bc 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -51,8 +51,8 @@ enum class DispatchKey : uint8_t { // Here are backends which you think of as traditionally specifying // how to implement operations on some device. - CPU, // registered at build/aten/src/ATen/CPUType.cpp - CUDA, // registered at build/aten/src/ATen/CUDAType.cpp + CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp + CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp HIP, // NB: I think this is not actually used, due to Note [Masquerading as // CUDA] FPGA, // Xilinx support lives out of tree at https://gitlab.com/pytorch-complex/vitis_kernels @@ -73,8 +73,8 @@ enum class DispatchKey : uint8_t { // Here are backends which specify more specialized operators // based on the dtype of the tensor. - QuantizedCPU, // registered at build/aten/src/ATen/QuantizedCPUType.cpp - QuantizedCUDA, // registered at build/aten/src/ATen/QuantizedCUDAType.cpp + QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp + QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp ComplexCPU, // lives out of tree at // https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex ComplexCUDA, // and @@ -97,10 +97,10 @@ enum class DispatchKey : uint8_t { // based on the layout of the tensor. Note that the sparse backends // are one case where ordering matters: sparse multi-dispatches with // the corresponding dense tensors, and must be handled before them. - MkldnnCPU, // registered at build/aten/src/ATen/MkldnnCPUType.cpp + MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp // NB: not to be confused with MKLDNN, which is Caffe2 only - SparseCPU, // registered at build/aten/src/ATen/SparseCPUType.cpp - SparseCUDA, // registered at build/aten/src/ATen/SparseCUDAType.cpp + SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp + SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp SparseHIP, // TODO: I think this is not actually used, due to Note // [Masquerading as CUDA] @@ -276,8 +276,8 @@ enum class DispatchKey : uint8_t { // See Note [Alias Dispatch Key : Autograd] Autograd, - Math, - DefaultBackend, + Math, // registered at build/aten/src/ATen/RegisterMath.cpp + DefaultBackend, // registered at build/aten/src/ATen/RegisterDefaultBackend.cpp // Define an alias key to represent end of alias dispatch keys. // If you add new alias keys after Autograd, please also update it here. diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 0ed2dff543fe..da896ffce497 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -168,15 +168,31 @@ def cpp_string(s: str) -> str: # code we want. Target = Enum('Target', ('DEFINITION', 'DECLARATION', 'REGISTRATION')) -# Dispatch keywords in native_functions.yaml that support all backends. -KEYWORD_ALL_BACKENDS = ('DefaultBackend', 'Math') +# Dispatch keys that "support all backends". These codegen slightly differently +# then backend specific keys. +def is_generic_dispatch_key(dk: str) -> bool: + return dk in {'DefaultBackend', 'Math'} + +# CUDA specific dispatch keys +def is_cuda_dispatch_key(dk: str) -> bool: + return 'CUDA' in dk + +# Generates RegisterSchema.cpp. Depending on the selector, either +# all schemas are registered, or only some are (in the case of +# selective build) +@dataclass(frozen=True) +class RegisterSchema: + selector: SelectiveBuilder + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + op_name = f"aten::{f.func.name}" + if not self.selector.is_operator_selected(op_name): + return None + return f'm.def({cpp_string(str(f.func))});\n' -# Generates {dispatch}Type.cpp (e.g., CPUType.cpp). This function is also -# reused to implement per-operator registration. It also generates -# TypeDefault.cpp when dispatch target is for all backends (dispatch is None or -# dispatch in KEYWORD_ALL_BACKENDS). +# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). # -# {dispatch}Type.cpp # - The primary function of this file is to register all of the # implementations for the given dispatch key to the dispatcher, # so they are available for use in PyTorch. If dispatch is @@ -190,12 +206,9 @@ def cpp_string(s: str) -> str: # API without having to disambiguate which overload you want # (as would be the case if you directly registered native:: # functions). -# -# This function is also used for a secondary purpose: the registration -# logic is also reused to implement per-operator registration. @dataclass(frozen=True) -class ComputeTypeMethod: - dispatch: Optional[str] +class RegisterDispatchKey: + dispatch_key: str # TODO: Give more precise type Union[Literal[Target.DEFINITION, # Target.REGISTRATION]]; requires Literal from typing_extensions @@ -208,17 +221,14 @@ class ComputeTypeMethod: def __post_init__(self) -> None: assert self.target is not Target.DECLARATION - if self.dispatch is None: - assert self.target is Target.REGISTRATION @method_with_native_function def __call__(self, f: NativeFunction) -> Optional[str]: # for mypy type refinement; would be fixed by TODO on target assert self.target is not Target.DECLARATION - if self.dispatch is not None: - if self.dispatch not in f.dispatch: - return None + if self.dispatch_key not in f.dispatch: + return None op_name = f"aten::{f.func.name}" if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name): @@ -228,18 +238,16 @@ def __call__(self, f: NativeFunction) -> Optional[str]: returns_type = native.returns_type(f.func.returns) args = native.arguments(f.func) args_str = ', '.join(map(str, args)) - dispatch_to_all_backends = self.dispatch is not None and self.dispatch in KEYWORD_ALL_BACKENDS if self.target is Target.DEFINITION: - assert self.dispatch is not None - impl_name = f"at::native::{f.dispatch[self.dispatch]}" + impl_name = f"at::native::{f.dispatch[self.dispatch_key]}" args_exprs_str = ', '.join(a.name for a in args) return_kw = " return " cuda_guard = "" - if dispatch_to_all_backends or 'CUDA' in self.dispatch: + if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key): self_args = (a for a in f.func.arguments if a.name == "self") # There is precedence for which argument we use to do @@ -264,9 +272,9 @@ def __call__(self, f: NativeFunction) -> Optional[str]: # TODO: There is probably a simpler version of this that # works just as well. - if f.device_guard and dispatch_to_all_backends and has_tensor_options: + if f.device_guard and is_generic_dispatch_key(self.dispatch_key) and has_tensor_options: cuda_guard = cuda_guard_from_tensor_options - elif f.device_guard and self.dispatch is not None and 'CUDA' in self.dispatch and has_tensor_options: + elif f.device_guard and is_cuda_dispatch_key(self.dispatch_key) and has_tensor_options: cuda_guard = f"""\ globalContext().lazyInitCUDA(); {cuda_guard_from_tensor_options} @@ -287,40 +295,21 @@ def __call__(self, f: NativeFunction) -> Optional[str]: """ elif self.target is Target.REGISTRATION: - if self.dispatch is None: - return f'm.def({cpp_string(str(f.func))});\n' - elif f.manual_kernel_registration: + if f.manual_kernel_registration: return None else: - if dispatch_to_all_backends: - type_name = f'TypeDefault::{name}' - else: - type_name = f'{self.dispatch}Type::{name}' - dispatcher_sig = DispatcherSignature.from_schema(f.func) # Figure out which signature the function is if local.use_c10_dispatcher() is UseC10Dispatcher.full: - payload = f"TORCH_FN({type_name})" + payload = f"TORCH_FN({name})" elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \ - f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))" + f"{dispatcher_sig.type()}>(TORCH_FN({name}))" else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper - payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})" - - # Annotate it with dispatch information if necessary - # - # NB: In the ordinary, TypeDerived code generation work flow, specification - # of the backend is handled by the enclosing block, so the torch::dispatch - # invocation here is strictly unnecessary. However, in the fbcode mobile - # only workflow using per-op registration, these registrations will get dumped - # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So - # the torch::dispatch specification here is important! See - # Note [Redundancy in registration code is OK] for how we handle redundant info. - if self.dispatch is not None: - payload = f"torch::dispatch(DispatchKey::{self.dispatch},\n{payload})\n" + payload = f"torch::CppFunction::makeUnboxedOnly(&{name})" return f'm.impl("{f.func.name}",\n{payload});\n' else: @@ -456,7 +445,7 @@ def compute_native_function_declaration(f: NativeFunction) -> List[str]: return rs -# Generates BackendSelectRegister.cpp, a series of kernels which provide +# Generates RegisterBackendSelect.cpp, a series of kernels which provide # specialized computation of dispatch key for operator signatures which cannot # be easily done automatically using templating. @dataclass(frozen=True) @@ -790,7 +779,7 @@ def compute_registration_declarations(f: NativeFunction) -> str: 'schema': f'aten::{f.func}', # TODO: What exactly is the semantics of the 'dispatch' field? 'dispatch': str(f.dispatch.keys() != {'Math'}), - 'default': str(any(k in f.dispatch for k in KEYWORD_ALL_BACKENDS)) + 'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch)) } return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} """ @@ -985,7 +974,9 @@ def make_file_manager(install_dir: str) -> FileManager: #include #include ''' - backends = [ + # NB: substrings in these dispatch keys matter, we do tests to see if + # a key contains, e.g., CUDA to classify it as a CUDA backend + dispatch_keys = [ "CPU", "SparseCPU", "MkldnnCPU", @@ -993,61 +984,50 @@ def make_file_manager(install_dir: str) -> FileManager: "SparseCUDA", "QuantizedCPU", "QuantizedCUDA", + "Math", + "DefaultBackend", ] if options.backend_whitelist: - backends = [b for b in backends if b in options.backend_whitelist] + dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist] - for dispatch in backends: - h_template = 'TypeDerived.h' - cpp_template = 'TypeDerived.cpp' + for dispatch_key in dispatch_keys: + cpp_template = 'RegisterDispatchKey.cpp' - fm = cuda_fm if 'CUDA' in dispatch else cpu_fm + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm - fm.write_with_template(f'{dispatch}Type.cpp', cpp_template, lambda: { - 'Type': f'{dispatch}Type', - 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', + fm.write_with_template(f'Register{dispatch_key}.cpp', cpp_template, lambda: { + 'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '', 'legacy_th_headers': - '#include ' if dispatch == "CPU" else - '#include ' if dispatch == "CUDA" else + '#include ' if dispatch_key == "CPU" else + '#include ' if dispatch_key == "CUDA" else '', - 'Backend': dispatch, - 'type_derived_method_definitions': list(mapMaybe( - ComputeTypeMethod(dispatch, Target.DEFINITION, selector), + 'DispatchKey': dispatch_key, + 'dispatch_definitions': list(mapMaybe( + RegisterDispatchKey(dispatch_key, Target.DEFINITION, selector), native_functions )), - 'function_registrations': list(mapMaybe( - ComputeTypeMethod(dispatch, Target.REGISTRATION, selector), + 'dispatch_registrations': list(mapMaybe( + RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector), native_functions )), }) del fm + # BackendSelect is generated specially + cpu_fm.write('RegisterBackendSelect.cpp', lambda: { + 'backend_select_method_definitions': + list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)), + 'backend_select_function_registrations': + list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)), + }) + schema_selector = selector if options.force_schema_registration: schema_selector = SelectiveBuilder.get_nop_selector() - - # TODO: split this file into separate files - cpu_fm.write('TypeDefault.cpp', lambda: { - 'type_method_definitions': - list(mapMaybe( - ComputeTypeMethod('Math', Target.DEFINITION, selector), - native_functions)) + - list(mapMaybe( - ComputeTypeMethod('DefaultBackend', Target.DEFINITION, selector), - native_functions)), - - 'function_registrations': list(mapMaybe( - ComputeTypeMethod(None, Target.REGISTRATION, schema_selector), - native_functions)), - - 'math_function_registrations': list(mapMaybe( - ComputeTypeMethod('Math', Target.REGISTRATION, selector), - native_functions)), - - 'default_backend_function_registrations': list(mapMaybe( - ComputeTypeMethod('DefaultBackend', Target.REGISTRATION, selector), - native_functions)), + cpu_fm.write('RegisterSchema.cpp', lambda: { + 'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)), }) + cpu_fm.write('Functions.h', lambda: { 'function_declarations': list(mapMaybe(ComputeFunction(Target.DECLARATION), native_functions)), }) @@ -1066,12 +1046,6 @@ def make_file_manager(install_dir: str) -> FileManager: cpu_fm.write('NativeFunctions.h', lambda: { 'native_function_declarations': list(concatMap(compute_native_function_declaration, native_functions)), }) - cpu_fm.write('BackendSelectRegister.cpp', lambda: { - 'backend_select_method_definitions': - list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)), - 'backend_select_function_registrations': - list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)), - }) cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) cpu_fm.write('RegistrationDeclarations.h', lambda: { From d7c8d3cccb540bbcbfc537d0094ee3b73998ea04 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 12 Nov 2020 09:57:12 -0800 Subject: [PATCH 55/93] Remove references to `typing` module from setup.py (#47677) Summary: It is part of core Python-3.6.2+ Fixes https://github.com/pytorch/pytorch/issues/47596 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47677 Reviewed By: walterddr Differential Revision: D24860188 Pulled By: malfet fbshipit-source-id: ad72b433a4493ebe5caca97c2e8a9d4b3c8172d4 --- .circleci/config.yml | 2 +- .circleci/docker/common/install_conda.sh | 5 ++--- .circleci/scripts/binary_ios_build.sh | 2 +- .circleci/verbatim-sources/job-specs/job-specs-custom.yml | 2 +- caffe2/requirements.txt | 1 - setup.py | 1 - 6 files changed, 5 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b5144dc703ea..2be90283d295 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1579,7 +1579,7 @@ jobs: $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } - retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes + retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes # sync submodules cd ${PROJ_ROOT} diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index b7ad26f44836..f7fbcf2a6e6b 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -72,14 +72,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # DO NOT install cmake here as it would install a version newer than 3.5, but # we want to pin to version 3.5. if [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then - # DO NOT install typing if installing python-3.8, since its part of python-3.8 core packages # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 elif [ "$ANACONDA_PYTHON_VERSION" = "3.7" ]; then # DO NOT install dataclasses if installing python-3.7, since its part of python-3.7 core packages - conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi typing future six + conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six else - conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi typing future six dataclasses + conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six dataclasses fi if [[ "$CUDA_VERSION" == 9.2* ]]; then conda_install magma-cuda92 -c pytorch diff --git a/.circleci/scripts/binary_ios_build.sh b/.circleci/scripts/binary_ios_build.sh index 1166b3a1bab7..4cfe778e5134 100644 --- a/.circleci/scripts/binary_ios_build.sh +++ b/.circleci/scripts/binary_ios_build.sh @@ -15,7 +15,7 @@ export PATH="~/anaconda/bin:${PATH}" source ~/anaconda/bin/activate # Install dependencies -conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes +conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes conda install -c conda-forge valgrind --yes export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml index aacb45f41a52..d5f07eefb4e2 100644 --- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml @@ -425,7 +425,7 @@ $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } - retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes + retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes # sync submodules cd ${PROJ_ROOT} diff --git a/caffe2/requirements.txt b/caffe2/requirements.txt index 7c0367da1d85..aa8d2be43aa5 100644 --- a/caffe2/requirements.txt +++ b/caffe2/requirements.txt @@ -2,4 +2,3 @@ numpy enum34 pyyaml requests -typing diff --git a/setup.py b/setup.py index 1b8998573fee..f49889fc688d 100644 --- a/setup.py +++ b/setup.py @@ -308,7 +308,6 @@ def check_file(f): 'benchmark', 'CMakeLists.txt')) check_pydep('yaml', 'pyyaml') - check_pydep('typing', 'typing') build_caffe2(version=version, cmake_python_library=cmake_python_library, From 9ea7a6c7c55484b83b8cedb58158f429425fe950 Mon Sep 17 00:00:00 2001 From: David Fan Date: Thu, 12 Nov 2020 10:15:10 -0800 Subject: [PATCH 56/93] [ONNX] Update ONNX doc for writing pytorch model (#46961) Summary: For tracing successfully, we need write pytorch model in torch way. So we add instructions with examples here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46961 Reviewed By: ailzhang Differential Revision: D24900040 Pulled By: bzinodev fbshipit-source-id: b375b533396b11dbc9656fa61e84a3f92f352e4b --- docs/source/onnx.rst | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 39f7c456353f..cdda93c60d3f 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -249,6 +249,31 @@ E.g.: :: out = model(*inputs) torch.onnx.export(model, inputs, 'loop_and_list.onnx', opset_version=11, example_outputs=out) +Write PyTorch model in Torch way +-------------------------------- + +PyTorch models can be written using numpy manipulations, but this is not proper when we convert to the ONNX model. +For the trace-based exporter, tracing treats the numpy values as the constant node, +therefore it calculates the wrong result if we change the input. +So the PyTorch model need implement using torch operators. +For example, do not use numpy operators on numpy tensors: :: + + np.concatenate((x, y, z), axis=1) + +do not convert to numpy types: :: + + y = x.astype(np.int) + +Always use torch tensors and torch operators: torch.concat, etc. +In addition, Dropout layer need defined in init function so that inferencing can handle it properly, i.e., :: + + class MyModule(nn.Module): + def __init__(self): + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = self.dropout(x) + Indexing -------- From 66f9b1de1ba16e4230ddcbb1ff35c64023335328 Mon Sep 17 00:00:00 2001 From: Mingzhe Li Date: Thu, 12 Nov 2020 10:43:10 -0800 Subject: [PATCH 57/93] [NCCL] enable p2p tests (#47797) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47797 NCCL p2p tests had hang issues before, the reason is that there were some unexpected context switches. For example, process 1 which is supposed to only use GPU1 could use GPU0 as a result of missing explicitly setting device. ghstack-source-id: 116461969 Test Plan: waitforsandcastle Reviewed By: jiayisuse Differential Revision: D24863808 fbshipit-source-id: 92bd3a4874be8334210c7c8ee6363648893c963e --- torch/distributed/distributed_c10d.py | 4 ++++ torch/testing/_internal/distributed/distributed_test.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d97fa774ef30..fb61ab571e21 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -864,6 +864,10 @@ def batch_isend_irecv(p2p_op_list): >>> recv_tensor tensor([2, 3]) # Rank 0 tensor([0, 1]) # Rank 1 + + .. note:: Note that when this API is used with the NCCL PG backend, users must set + the current GPU device with `torch.cuda.set_device`, otherwise it will + lead to unexpected hang issues. """ _check_p2p_op_list(p2p_op_list) backend = get_backend(p2p_op_list[0].group) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 641ccd739ec7..8bd31c97c98f 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -588,7 +588,6 @@ def test_backend_full_group(self): # NCCL Batch SEND RECV @skip_if_no_gpu - @unittest.skip("NCCL P2P is not enabled for OSS builds") @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") def test_batch_isend_irecv_nccl(self): @@ -596,6 +595,7 @@ def test_batch_isend_irecv_nccl(self): rank = dist.get_rank() rank_to_GPU = self._init_multigpu_helper() device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) p2p_op_list = [] for val in ["1", "0"]: @@ -640,7 +640,6 @@ def test_batch_isend_irecv_self_nccl(self): @skip_if_no_gpu @skip_if_small_worldsize - @unittest.skip("NCCL P2P is not enabled for OSS builds") @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") def test_batch_isend_irecv_no_rank_zero_nccl(self): @@ -648,6 +647,7 @@ def test_batch_isend_irecv_no_rank_zero_nccl(self): rank = dist.get_rank() rank_to_GPU = self._init_multigpu_helper() device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) p2p_op_list = [] if rank == 1: @@ -794,6 +794,7 @@ def test_send_recv_nccl(self): rank = dist.get_rank() rank_to_GPU = self._init_multigpu_helper() device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) tensor = _build_tensor(rank + 1, device_id=device_id) From 1478e5ec2aa42b2a9742257642c7c1d3203d7309 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 12 Nov 2020 10:54:39 -0800 Subject: [PATCH 58/93] [quant] Remove nn.quantized.ReLU module and nn.quantized.functional.relu (#47415) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47415 nn.ReLU works for both float and quantized input, we don't want to define an nn.quantized.ReLU that does the same thing as nn.ReLU, similarly for nn.quantized.functional.relu this also removes the numerical inconsistency for models quantizes nn.ReLU independently in qat mode Test Plan: Imported from OSS Reviewed By: z-a-f Differential Revision: D24747035 fbshipit-source-id: b8fdf13e513a0d5f0c4c6c9835635bdf9fdc2769 --- docs/source/quantization-support.rst | 2 - docs/source/torch.nn.quantized.rst | 11 +----- test/quantization/test_quantize.py | 26 ++++++++++++- .../quantization/test_quantized_functional.py | 2 +- test/quantization/test_quantized_module.py | 12 +++--- test/quantization/test_quantized_op.py | 28 -------------- torch/nn/quantized/functional.py | 17 --------- torch/nn/quantized/modules/__init__.py | 3 +- torch/nn/quantized/modules/activation.py | 38 ------------------- torch/nn/quantized/modules/batchnorm.py | 10 ++--- torch/nn/quantized/modules/conv.py | 10 +---- torch/nn/quantized/modules/linear.py | 4 +- .../quantization/fx/quantization_patterns.py | 23 ++--------- torch/quantization/quantization_mappings.py | 1 - torch/quantization/quantize.py | 4 +- torch/quantization/quantize_fx.py | 4 +- .../testing/_internal/common_quantization.py | 4 +- 17 files changed, 51 insertions(+), 148 deletions(-) diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index a1b13e1ada1f..60be24120d43 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -255,7 +255,6 @@ Quantized version of standard NN layers. * :class:`~torch.nn.quantized.Conv3d` — 3D convolution * :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer * :class:`~torch.nn.MaxPool2d` — 2D max pooling -* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit * :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at quantized representation of 6 * :class:`~torch.nn.quantized.ELU` — ELU @@ -294,7 +293,6 @@ quantization output parameters) * :func:`~torch.nn.quantized.functional.interpolate` — Down-/up- sampler * :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op * :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling -* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit * :func:`~torch.nn.quantized.functional.elu` — ELU * :func:`~torch.nn.quantized.functional.hardsigmoid` — Hardsigmoid * :func:`~torch.nn.quantized.functional.hardswish` — Hardswish diff --git a/docs/source/torch.nn.quantized.rst b/docs/source/torch.nn.quantized.rst index a9aaa51a33bf..aeb3b55cd5fd 100644 --- a/docs/source/torch.nn.quantized.rst +++ b/docs/source/torch.nn.quantized.rst @@ -1,14 +1,12 @@ torch.nn.quantized ------------------ -This module implements the quantized versions of the nn layers such as -~`torch.nn.Conv2d` and `torch.nn.ReLU`. +This module implements the quantized versions of the nn modules and functionals. Functional interface ~~~~~~~~~~~~~~~~~~~~ .. automodule:: torch.nn.quantized.functional -.. autofunction:: relu .. autofunction:: linear .. autofunction:: conv1d .. autofunction:: conv2d @@ -25,11 +23,6 @@ Functional interface .. automodule:: torch.nn.quantized -ReLU -~~~~~~~~~~~~~~~ -.. autoclass:: ReLU - :members: - ReLU6 ~~~~~~~~~~~~~~~ .. autoclass:: ReLU6 @@ -119,5 +112,3 @@ InstanceNorm3d ~~~~~~~~~~~~~~~ .. autoclass:: InstanceNorm3d :members: - - diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index e0f09fab95d7..597dbb2c6dc0 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -307,8 +307,8 @@ def checkQuantized(model): self.checkQuantDequant(model.sub) self.checkQuantizedLinear(model.sub.module.fc1) self.checkQuantizedLinear(model.sub.module.fc2) - self.assertEqual(type(model.sub.module.relu1), nnq.ReLU) - self.assertEqual(type(model.sub.module.relu2), nnq.ReLU) + self.assertEqual(type(model.sub.module.relu1), nn.ReLU) + self.assertEqual(type(model.sub.module.relu2), nn.ReLU) self.checkScriptable(model, self.calib_data) self.checkNoQconfig(model) @@ -1248,6 +1248,9 @@ def forward(self, x): def test_leaky_relu(self): self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False}) + def test_relu(self): + self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False}) + class TestEagerModeQATOps(QuantizationTestCase): def _test_activation_convert_numerics_impl(self, Act, data): @@ -1325,6 +1328,25 @@ def test_leaky_relu(self): data = torch.randn(1, 3, 2, 4) self._test_activation_convert_numerics_impl(nn.LeakyReLU, data) + def test_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(x) + return x + + m = M().train() + m.qconfig = default_qconfig + m = prepare_qat(m) + # make sure no activation_post_process is inserted for relu + self.assertFalse(hasattr(m, "activation_post_process")) + m = convert(m) + # make sure ReLU module is not changed + self.assertTrue(type(m.relu), nn.ReLU) + class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) diff --git a/test/quantization/test_quantized_functional.py b/test/quantization/test_quantized_functional.py index 548b0677efe0..59242493d869 100644 --- a/test/quantization/test_quantized_functional.py +++ b/test/quantization/test_quantized_functional.py @@ -26,7 +26,7 @@ def test_relu_api(self): zero_point = 1 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8) qY = torch.relu(qX) - qY_hat = qF.relu(qX) + qY_hat = F.relu(qX) self.assertEqual(qY, qY_hat) def _test_conv_api_impl( diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 5780f3ffcbdf..60cb1f3b2480 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -39,7 +39,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): def test_relu(self): - relu_module = nnq.ReLU() + relu_module = nn.ReLU() + # TODO: remove nnq.ReLU6 and remove this test relu6_module = nnq.ReLU6() x = torch.arange(-10, 10, dtype=torch.float) @@ -304,10 +305,11 @@ def _test_conv_api_impl( check_save_load=True) # Test from_float - conv_module.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(conv_module, inplace=True) - conv_module(X.float()) - converted_qconv_module = torch.nn.Sequential(conv_module) + fused_conv_module = torch.nn.intrinsic._FusedModule(conv_module) + fused_conv_module.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(fused_conv_module, inplace=True) + fused_conv_module(X.float()) + converted_qconv_module = fused_conv_module torch.quantization.convert(converted_qconv_module, inplace=True) # Smoke test to make sure the module actually runs diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index ee6a757a5c9e..12871f6b1aa0 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -216,34 +216,6 @@ def _test_activation_function(self, X, fn_name, test_configs): fn_name, q_op, qY, qY_hat )) - """Tests the correctness of the quantized::relu op.""" - @override_qengines - @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), - qparams=hu.qparams())) - def test_qrelu(self, X): - relu_test_configs = [ - { - 'quantized_fn': [ - torch.relu, - torch.relu_, - torch.nn.functional.relu, - torch.nn.quantized.functional.relu, - ], - 'reference_fn': torch.nn.functional.relu - }, - { - 'quantized_fn': [ - torch.nn.functional.relu, - torch.nn.quantized.functional.relu, - ], - 'reference_fn': torch.nn.functional.relu, - 'extra_kwargs': { - 'inplace': True - } - } - ] - self._test_activation_function(X, 'relu', relu_test_configs) - """Tests the correctness of the quantized::relu6 op.""" @override_qengines @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index f2b090370ed3..ed9fd5a08ccc 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -410,23 +410,6 @@ def celu(input: Tensor, scale: float, zero_point: int, alpha: Optional[float] = return torch.ops.quantized.celu(input, scale, zero_point, alpha) -def relu(input: Tensor, inplace: bool = False) -> Tensor: - r"""relu(input, inplace=False) -> Tensor - - Applies the rectified linear unit function element-wise. - See :class:`~torch.nn.quantized.ReLU` for more details. - - Args: - input: quantized input - inplace: perform the computation inplace - """ - if not input.is_quantized: - raise ValueError("Input to 'quantized.relu' must be quantized!") - if inplace: - return torch.relu_(input) - else: - return torch.relu(input) - def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False, scale: float = None, zero_point: int = None): r""" diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index a064c72dda98..f6e1a8af5d39 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -2,7 +2,7 @@ import torch from torch.nn.modules.pooling import MaxPool2d -from .activation import ReLU, ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid +from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid from .batchnorm import BatchNorm2d, BatchNorm3d from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ InstanceNorm2d, InstanceNorm3d @@ -106,7 +106,6 @@ def from_float(mod): 'Linear', 'MaxPool2d', 'Quantize', - 'ReLU', 'ReLU6', 'Sigmoid', # Wrapper modules diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py index 366e1e63a039..234fd777d703 100644 --- a/torch/nn/quantized/modules/activation.py +++ b/torch/nn/quantized/modules/activation.py @@ -1,44 +1,6 @@ import torch import torch.nn.quantized.functional -class ReLU(torch.nn.ReLU): - r"""Applies quantized rectified linear unit function element-wise: - - :math:`\text{ReLU}(x)= \max(x_0, x)`, where :math:`x_0` is the zero point. - - Please see https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU - for more documentation on ReLU. - - Args: - inplace: (Currently not supported) can optionally do the operation in-place. - - Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions - - Output: :math:`(N, *)`, same shape as the input - - Examples:: - - >>> m = nn.quantized.ReLU() - >>> input = torch.randn(2) - >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) - >>> output = m(input) - """ - def __init__(self, inplace=False): - super(ReLU, self).__init__(inplace) - self.inplace = inplace - - def forward(self, input): - return torch.nn.quantized.functional.relu(input, inplace=self.inplace) - - def _get_name(self): - return 'QuantizedReLU' - - @staticmethod - def from_float(mod): - return ReLU(mod.inplace) - - class ReLU6(torch.nn.ReLU): r"""Applies the element-wise function: diff --git a/torch/nn/quantized/modules/batchnorm.py b/torch/nn/quantized/modules/batchnorm.py index c3e028b191b4..189d402ee2a5 100644 --- a/torch/nn/quantized/modules/batchnorm.py +++ b/torch/nn/quantized/modules/batchnorm.py @@ -21,11 +21,9 @@ def _get_name(self): @classmethod def from_float(cls, mod): + activation_post_process = mod.activation_post_process if type(mod) == nni.BNReLU2d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process scale, zero_point = activation_post_process.calculate_qparams() new_mod = cls(mod.num_features, mod.eps) new_mod.weight = mod.weight @@ -36,6 +34,7 @@ def from_float(cls, mod): new_mod.zero_point = int(zero_point) return new_mod +# TODO: dedup with BatchNorm2d class BatchNorm3d(torch.nn.BatchNorm3d): r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`. """ @@ -55,12 +54,9 @@ def _get_name(self): @classmethod def from_float(cls, mod): + activation_post_process = mod.activation_post_process if type(mod) == nni.BNReLU3d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process - scale, zero_point = activation_post_process.calculate_qparams() new_mod = cls(mod.num_features, mod.eps) new_mod.weight = mod.weight diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 6a2d98f0a8fe..a9ba3293630d 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -182,11 +182,9 @@ def from_float(cls, mod): cls._FLOAT_MODULE.__name__ assert hasattr(mod, "qconfig"), \ "Input float module must have qconfig defined." + activation_post_process = mod.activation_post_process if type(mod) == cls._NNI_CONV_RELU_MODULE: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process weight_post_process = mod.qconfig.weight() return cls.get_qconv(mod, activation_post_process, weight_post_process) @@ -449,13 +447,9 @@ def from_float(cls, mod): cls._FLOAT_MODULE.__name__ assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' - # Workaround for sequential, ConvReLU3d should probably inherit from - # Conv3d instead + activation_post_process = mod.activation_post_process if type(mod) == nni.ConvReLU3d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process return cls.get_qconv(mod, activation_post_process) # === Transposed Convolutions === diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 4d27dad07bc1..d0ac86e020df 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -252,11 +252,9 @@ def from_float(cls, mod): assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + activation_post_process = mod.activation_post_process if type(mod) == nni.LinearReLU: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process weight_post_process = mod.qconfig.weight() weight_post_process(mod.weight) dtype = weight_post_process.dtype diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 8b3c96306324..98263e80d3fb 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -212,14 +212,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ convert_custom_config_dict = {} additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) # 1. attach activation post process to module - if type(self.conv) in [ - torch.nn.intrinsic.ConvReLU1d, - torch.nn.intrinsic.ConvReLU2d, - torch.nn.intrinsic.ConvReLU3d - ]: - self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name] - else: - self.conv.activation_post_process = quantizer.activation_post_process_map[node.name] + self.conv.activation_post_process = quantizer.activation_post_process_map[node.name] # 2. select quantized class qconv_cls = get_static_quant_module_class( type(self.conv), additional_static_quant_mapping) @@ -315,11 +308,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ output_activation_post_process = None if output_activation_post_process: - if type(self.linear) == torch.nn.intrinsic.LinearReLU: - float_linear_module = self.linear[1] - else: - float_linear_module = self.linear - float_linear_module.activation_post_process = output_activation_post_process + self.linear.activation_post_process = output_activation_post_process # 2. select corresponding quantized linear class for the float linear class if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]: @@ -416,13 +405,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ convert_custom_config_dict = {} additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) # 1. attach activation post process to module - activation_post_process = quantizer.activation_post_process_map[node.name] - if type(self.bn) in \ - [torch.nn.intrinsic.BNReLU2d, - torch.nn.intrinsic.BNReLU3d]: - self.bn[1].activation_post_process = activation_post_process - else: - self.bn.activation_post_process = activation_post_process + self.bn.activation_post_process = quantizer.activation_post_process_map[node.name] qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping) quantized = qbn_cls.from_float(self.bn) parent_name, name = _parent_name(self.bn_node.target) diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 0aa5cb845e12..88d264b1ccf3 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -39,7 +39,6 @@ nn.LeakyReLU: nnq.LeakyReLU, nn.Linear: nnq.Linear, nn.ReLU6: nnq.ReLU6, - nn.ReLU: nnq.ReLU, # Wrapper Modules: nnq.FloatFunctional: nnq.QFunctional, # Intrinsic modules: diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index a7cc5b7b3893..f7471ee9fe9b 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -1,4 +1,3 @@ - import copy import itertools import warnings @@ -137,7 +136,8 @@ def insert_activation_post_process(m, special_act_post_process=None): m._forward_hooks.move_to_end(handle.id, last=False) for name, child in module.named_children(): - if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional: + if type(child) in [nnq.FloatFunctional, nnq.QFunctional] or \ + isinstance(child, _FusedModule): if needs_observation(child): child.activation_post_process = get_activation_post_process(child.qconfig, device) elif _has_special_act_post_process(child): diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 91d58c2966a4..202e9921432a 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -5,6 +5,7 @@ from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 +from torch.nn.intrinsic import _FusedModule def _check_is_graph_module(model): if not isinstance(model, GraphModule): @@ -47,7 +48,8 @@ def is_leaf_module(self, m, module_qualified_name): return (m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)) or \ module_qualified_name in self.skipped_module_names or \ - type(m) in self.skipped_module_classes + type(m) in self.skipped_module_classes or \ + isinstance(m, _FusedModule) def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False): diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 1b2b4165b044..246577ab3d08 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +from torch.nn.intrinsic import _FusedModule import torch.distributed as dist from torch.testing._internal.common_utils import TestCase @@ -365,7 +366,8 @@ def is_leaf_module(module): # we don't need to check observers for child modules of the # qat modules if type(module) not in get_default_qat_module_mappings().values() and \ - type(module) not in float_to_observed_module_class_mapping.values(): + type(module) not in float_to_observed_module_class_mapping.values() and \ + not isinstance(module, _FusedModule): for child in module.children(): self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict) From f42cdc2e43351f8fc77c7d399bcbd398215f5404 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Thu, 12 Nov 2020 11:00:16 -0800 Subject: [PATCH 59/93] [NNC] Fix printing of integral doubles (#47799) Summary: When printing doubles, we don't do anything to distinguish intregal doubles (ie, 1 or 2) from ints. Added decoration of these doubles with `.0` if they are integral (i.e. DoubleImm(1) will print as `1.0`). This is an issue specifically on Cuda where some intrinsics do not have type coercion. Added a test which covers this case (without the fix it tries to look up pow(double, int) which doesn't exist). Fixes https://github.com/pytorch/pytorch/issues/47304 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47799 Reviewed By: ZolotukhinM Differential Revision: D24904185 Pulled By: nickgg fbshipit-source-id: baa38726966c94ee50473cc046b9ded5c4e748f7 --- test/test_tensorexpr.py | 12 ++++++++++++ torch/csrc/jit/tensorexpr/ir_printer.cpp | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 02a72a11c73e..b673e0233a0d 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -966,6 +966,18 @@ def test_min(x, y): assert np.isnan(warmup_and_run_forward(tmax, y, x).item()) self.assertLastGraphAllFused() + def test_double_intrinsics(self): + devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] + + def do_pow(x): + return torch.pow(x, 7) + + for device in devices: + x = torch.rand(10, dtype=torch.double, device=device) + traced = torch.jit.trace(do_pow, (x)) + x = warmup_and_run_forward(traced, x) + self.assertLastGraphAllFused() + def test_remainder(self): def run_remainder(x, y): c = torch.remainder(torch.add(x, y), x) diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 730b170ca40b..ef8135c6887c 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -175,7 +175,7 @@ void IRPrinter::visit(const CompareSelect* v) { } static void formatFPSuffix(std::ostream& os, double v) { - // No suffix for doubles. + os << (v == std::ceil(v) ? ".0" : ""); } template From f221a19a7ffe57f5d46f4736c7507ae87b01079c Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 60/93] Force LLVM Compilation for CPU Tests (#46949) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46949 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805247 Pulled By: eellison fbshipit-source-id: 4fcaf02d8a78cc5cbcbde36940d0a2c85fba3fc5 --- test/test_jit_fuser_te.py | 16 +++++----------- torch/csrc/jit/codegen/fuser/interface.h | 3 +++ torch/csrc/jit/python/init.cpp | 12 ++++++++++++ torch/csrc/jit/tensorexpr/kernel.cpp | 8 ++++++++ torch/csrc/jit/tensorexpr/kernel.h | 1 + 5 files changed, 29 insertions(+), 11 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 257aa1d0b143..4d5bb06588ac 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -15,7 +15,7 @@ torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) -from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ +from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests from torch.testing._internal.jit_utils import JitTestCase, _inline_everything, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward @@ -52,9 +52,12 @@ def texpr_reductions_enabled(): class TestTEFuser(JitTestCase): def setUp(self): self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() + self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu() self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() torch._C._jit_override_can_fuse_on_cpu(True) + # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle + # torch._C._jit_set_te_must_use_llvm_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) @@ -69,6 +72,7 @@ def tearDown(self): torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) + torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) @@ -93,7 +97,6 @@ def func(x): self.assertEqual(len(fusion_groups), 1) FileCheck().check("aten::abs").check("aten::mul").run(str(fusion_groups[0])) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_simple(self): def func(x): x2 = x * x @@ -108,7 +111,6 @@ def func(x): self.assertEqual(len(fusion_groups), 1) self.assertEqual(scripted(a), func(a)) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_dim(self): def func(x): return x.sum((0, )) * 2 @@ -122,7 +124,6 @@ def func(x): self.assertEqual(len(fusion_groups), 1) self.assertEqual(scripted(a), func(a)) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_keepdim_cast(self): def func(x): return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 @@ -136,7 +137,6 @@ def func(x): self.assertEqual(len(fusion_groups), 1) self.assertEqual(scripted(a), func(a)) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_abs_cpu(self): self._test_fused_abs() @@ -158,7 +158,6 @@ def decode(sin_t, cos_t): def test_zero_element_tensors_cuda(self): self._test_zero_element_tensors(device="cuda") - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_zero_element_tensors_cpu(self): self._test_zero_element_tensors(device="cpu") @@ -289,7 +288,6 @@ def chunk_4_last(x): for fn in fns: self.checkScript(fn, [tensor]) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_chunk_correctness(self): return self._test_chunk_correctness(self, 'cpu') @@ -791,7 +789,6 @@ def fn_test_scalar_arg_requires_grad(x, p): self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") def test_fuser_deduplication(self): # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation @@ -813,7 +810,6 @@ def f(x, y): # check that a, b share storage, i.e. were generated as a single output in the fuser self.assertEqual(ga2.data_ptr(), gb2.data_ptr()) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833") def test_fuser_iou(self): # This checks if most of Intersection over Union is fused. @@ -978,7 +974,6 @@ def test_lstm_traced_cuda(self): self.assertEqual(len(fusion_groups), 1) FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0])) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746") def test_lstm_traced_cpu(self): inputs = get_lstm_inputs('cpu') @@ -1170,7 +1165,6 @@ def should_fuse_scalar(x, z): self.assertGraphContainsExactly( ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_where_and_typing(self): def f(x, y): mask = x > y diff --git a/torch/csrc/jit/codegen/fuser/interface.h b/torch/csrc/jit/codegen/fuser/interface.h index f3272cbc38cd..4d6220dc9ed6 100644 --- a/torch/csrc/jit/codegen/fuser/interface.h +++ b/torch/csrc/jit/codegen/fuser/interface.h @@ -32,6 +32,9 @@ TORCH_API bool canFuseOnGPU(); // flakiness) TORCH_API void overrideCanFuseOnCPU(bool value); +// Sets whether fusion on CPU must use LLVM Codegen and not SimplieIREval +TORCH_API void overrideMustUseLLVMOnCPU(bool value); + // Sets whether fusion on the GPU is allowed (enabled by default) TORCH_API void overrideCanFuseOnGPU(bool value); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 7e3ebd86e9e9..9cc2f07b2e6c 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -617,6 +617,18 @@ void initJITBindings(PyObject* module) { using namespace torch::jit::tensorexpr; return getTEGenerateBlockCode(); }) + .def( + "_jit_get_te_must_use_llvm_cpu", + []() -> bool { + using namespace torch::jit::tensorexpr; + return getTEMustUseLLVMOnCPU(); + }) + .def( + "_jit_set_te_must_use_llvm_cpu", + [](bool use_llvm) { + using namespace torch::jit::tensorexpr; + getTEMustUseLLVMOnCPU() = use_llvm; + }) .def( "_jit_pass_fuse_tensorexprs", [](std::shared_ptr& g) { return FuseTensorExprs(g); }) diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 97d3bf7ce9d8..3c591ffef3e1 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -20,6 +20,7 @@ static int te_cuda_pointwise_block_count = -1; static int te_cuda_pointwise_block_size = -1; static bool fallback_allowed = false; static bool te_generate_block_code = false; +static bool te_must_use_llvm_on_cpu = false; bool setFallbackAllowed(bool value) { bool old_value = fallback_allowed; @@ -68,6 +69,10 @@ bool& getTEGenerateBlockCode() { return te_generate_block_code; } +bool& getTEMustUseLLVMOnCPU() { + return te_must_use_llvm_on_cpu; +} + c10::optional pickDeviceType( const at::ArrayRef& inputs) { c10::optional device = c10::nullopt; @@ -1547,6 +1552,9 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( #else backendType = kSimpleIREval; #endif + if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) { + throw std::runtime_error("LLVM Backend not found"); + } } else { throw std::runtime_error("Invalid device type"); } diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 91452e411a6e..c6df7f4c8469 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -209,6 +209,7 @@ TORCH_API int& getTECudaPointwiseLoopLevels(); TORCH_API int& getTECudaPointwiseBlockCount(); TORCH_API int& getTECudaPointwiseBlockSize(); TORCH_API bool& getTEGenerateBlockCode(); +TORCH_API bool& getTEMustUseLLVMOnCPU(); TORCH_API bool fallbackAllowed(); TORCH_API bool setFallbackAllowed(bool value); From ad5be26b2fee9303a93ec921be450a40504969ed Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 61/93] Small changes/cleanup (#46950) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46950 Make sure that we're fusing in a fuse tests, and refactor to more concise API to check if fusions have happened. Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805250 Pulled By: eellison fbshipit-source-id: f898008a64b74e761bb5fe85f91b3cdf2dbdf878 --- test/test_jit_fuser_te.py | 47 +++++++++++++-------------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 4d5bb06588ac..324839a636f5 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -76,6 +76,9 @@ def tearDown(self): torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) + def assertLastGraphAllFused(self): + self.assertAllFused(torch.jit.last_executed_optimized_graph()) + def findFusionGroups(self, graph): result = [] for n in graph.nodes(): @@ -92,10 +95,7 @@ def func(x): a = torch.randn(5, device=device) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::abs").check("aten::mul").run(str(fusion_groups[0])) + self.assertLastGraphAllFused() def test_sum_simple(self): def func(x): @@ -106,10 +106,7 @@ def func(x): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) + self.assertLastGraphAllFused() def test_sum_dim(self): def func(x): @@ -119,10 +116,7 @@ def func(x): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) + self.assertLastGraphAllFused() def test_sum_keepdim_cast(self): def func(x): @@ -131,11 +125,9 @@ def func(x): with texpr_reductions_enabled(): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) - scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) + + self.checkScript(func, (a,)) + self.assertLastGraphAllFused() def test_abs_cpu(self): self._test_fused_abs() @@ -186,11 +178,8 @@ def scaleshift(x, scale, shift): torch.randn(4, dtype=torch.float, device='cuda'), torch.randn(4, dtype=torch.float, device='cuda'), ] - ge = self.checkTrace(scaleshift, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::mul").check("aten::add").run(str(fusion_groups[0])) + self.checkScript(scaleshift, inputs) + self.assertLastGraphAllFused() @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") @@ -253,10 +242,8 @@ def fn(x): inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')] - ge = self.checkScript(fn, inputs) - graph = ge.graph_for(*inputs) - self.assertAllFused(graph) - FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph)) + self.checkScript(fn, inputs) + self.assertLastGraphAllFused() @staticmethod def _test_chunk_correctness(self, device='cpu'): @@ -287,6 +274,7 @@ def chunk_4_last(x): for tensor in tensors: for fn in fns: self.checkScript(fn, [tensor]) + self.assertLastGraphAllFused() def test_chunk_correctness(self): return self._test_chunk_correctness(self, 'cpu') @@ -329,11 +317,8 @@ def func2(x): torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float), ] for func in [func1, func2]: - module = self.checkScript(func, inputs) - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) - fusion_group = list(forward_graph.nodes())[-1] - self.assertEqual(len(list(fusion_group.inputs())), 1) + self.checkScript(func, inputs) + self.assertLastGraphAllFused() @unittest.skipIf(not RUN_CUDA, "No CUDA") def test_chunk_multiple_cuda(self): From b8a1070ec0b20ae634ff4c84faf4bf21a1bc6d6b Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 62/93] [TensorExpr][CPU] Fix bool -> int casting (#46951) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46951 If e.g. we're casting from torch.int -> torch.bool, previously we would just truncate from int32 -> i8. Since torch.bool has 8 bits but only uses one of them, we need to makes sure that one bit is set. Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805253 Pulled By: eellison fbshipit-source-id: af3aa323f10820d189827eb51037adfa7d80fed9 --- test/test_jit_fuser_te.py | 22 +++++----- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 48 +++++++++++++++------- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 324839a636f5..7392db3a757b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -66,6 +66,8 @@ def setUp(self): self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) + self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) torch._C._jit_set_profiling_mode(self.old_profiling_mode) @@ -411,17 +413,17 @@ def func(x): graph = backward_graph(s, skip_check=True) self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'}) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_add_bool(self): - def f(x, y, z): - return x + y + z - - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + sizes = [(1,), (2,), (4, 4)] + for device, size in product(self.devices, sizes): + def f(x, y, z): + return x + y + z + + x = torch.randint(0, 2, size, dtype=torch.bool, device=device) + y = torch.randint(0, 2, size, dtype=torch.bool, device=device) + z = torch.randint(0, 2, size, dtype=torch.bool, device=device) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_mul_bool(self): diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 8799bc4ff051..0cc7b1f65852 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -124,6 +124,7 @@ class LLVMCodeGenImpl : public IRVisitor { llvm::Type* dtypeToLLVMPtr(Dtype dtype); void emitWrapper(const std::vector& params); void emitKernel(Stmt* stmt, const std::vector& params); + llvm::Value* toVec(llvm::Value* v, int lanes); public: LLVMCodeGenImpl( @@ -826,7 +827,8 @@ void LLVMCodeGenImpl::visit(const Cast* v) { return; } - bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte; + bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte || + v->dtype().scalar_type() == ScalarType::Bool; // Scalar casts if (srcType->isFPOrFPVectorTy()) { @@ -841,18 +843,28 @@ void LLVMCodeGenImpl::visit(const Cast* v) { } else { throw unimplemented_lowering(v); } - } else if (srcType->isIntOrIntVectorTy()) { - if (dstType->isFPOrFPVectorTy()) { - if (destUnsigned) { - value_ = irb_.CreateUIToFP(value_, dstType); - } else { - value_ = irb_.CreateSIToFP(value_, dstType); - } - } else if (dstType->isIntOrIntVectorTy()) { - value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + return; + } + if (!srcType->isIntOrIntVectorTy()) { + throw unimplemented_lowering(v); + } + if (dstType->isFPOrFPVectorTy()) { + if (destUnsigned) { + value_ = irb_.CreateUIToFP(value_, dstType); } else { - throw unimplemented_lowering(v); + value_ = irb_.CreateSIToFP(value_, dstType); } + } else if (dstType->isIntOrIntVectorTy()) { + // Ensure bool true value is exactly one, since we convert to int + // from bool by zero extending the int8 + if (v->dtype().scalar_type() == ScalarType::Bool) { + llvm::Value* zero = + toVec(llvm::ConstantInt::get(srcType, 0), v->dtype().lanes()); + value_ = irb_.CreateICmpNE(value_, zero); + } + value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + } else { + throw unimplemented_lowering(v); } } @@ -1287,6 +1299,14 @@ struct FunctionCallee { } // namespace +llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) { + if (lanes > 1) { + return irb_.CreateVectorSplat(lanes, v); + } else { + return v; + } +} + void LLVMCodeGenImpl::visit(const Intrinsics* v) { llvm::FunctionType* call_ty = nullptr; llvm::Value* call_fn = nullptr; @@ -1297,10 +1317,8 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { case kRsqrt: { v->params().front()->accept(this); value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_); - llvm::Value* constant = llvm::ConstantFP::get(FloatTy_, 1.0); - if (v->dtype().lanes() > 1) { - constant = irb_.CreateVectorSplat(v->dtype().lanes(), constant); - } + llvm::Value* constant = + toVec(llvm::ConstantFP::get(FloatTy_, 1.0), v->dtype().lanes()); value_ = irb_.CreateFDiv(constant, value_); return; } break; From fe81faee5f65dfd3c015c7337729f70b51bba10e Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 63/93] Add more CPU tests (#47369) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47369 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805251 Pulled By: eellison fbshipit-source-id: f1a8210ffdc3cc88354cb4896652151d83a0345a --- test/test_jit_fuser_te.py | 372 ++++++++++++++++----------------- torch/_C/__init__.pyi.in | 1 + torch/csrc/jit/python/init.cpp | 7 + 3 files changed, 192 insertions(+), 188 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 7392db3a757b..bbefd1fc3ab8 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -170,18 +170,17 @@ def f(x, y): traced_f = torch.jit.trace(f, (x, y,)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_broadcast_cuda(self): - def scaleshift(x, scale, shift): - return x * scale + shift - - inputs = [ - torch.randn(4, 4, dtype=torch.float, device='cuda'), - torch.randn(4, dtype=torch.float, device='cuda'), - torch.randn(4, dtype=torch.float, device='cuda'), - ] - self.checkScript(scaleshift, inputs) - self.assertLastGraphAllFused() + def test_broadcast(self): + for device in self.devices: + def scaleshift(x, scale, shift): + return x * scale + shift + + inputs = [ + torch.randn(4, 4, dtype=torch.float, device=device), + torch.randn(4, dtype=torch.float, device=device), + torch.randn(4, dtype=torch.float, device=device), + ] + self.checkScript(scaleshift, inputs) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") @@ -219,33 +218,33 @@ def test_cuda_half(self): grads_half = [t.half() for t in grads] self.assertEqual(grads_half, fusion_grads) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_checks_cat_inputs(self): - # We shouldn't treat cat nodes as broadcasting. All their inputs - # need to be checked for having the same map size, before we can - # run the kernel. - def f(x, y): - return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) - - # NOTE: y is broadcastable to x, but output of f(x, y) should have - # shape 3x4, and not 4x4. - x = torch.randn(2, 4, dtype=torch.float, device='cuda') - y = torch.randn(1, 4, dtype=torch.float, device='cuda') - - scripted = self.checkScript(f, (x, y)) - self.assertEqual(scripted(x, y).shape, (3, 4)) - self.assertAllFused(scripted.graph_for(x, y)) - - @unittest.skipIf(not RUN_CUDA, "No CUDA") - def test_chunk_cuda(self): - def fn(x): - a, b, c = x.chunk(3, 1) - return a * b + c - - inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')] - - self.checkScript(fn, inputs) - self.assertLastGraphAllFused() + for device in self.devices: + # We shouldn't treat cat nodes as broadcasting. All their inputs + # need to be checked for having the same map size, before we can + # run the kernel. + def f(x, y): + return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) + + # NOTE: y is broadcastable to x, but output of f(x, y) should have + # shape 3x4, and not 4x4. + x = torch.randn(2, 4, dtype=torch.float, device=device) + y = torch.randn(1, 4, dtype=torch.float, device=device) + + scripted = self.checkScript(f, (x, y)) + self.assertEqual(scripted(x, y).shape, (3, 4)) + self.assertAllFused(scripted.graph_for(x, y)) + + def test_chunk(self): + for device in self.devices: + def fn(x): + a, b, c = x.chunk(3, 1) + return a * b + c + + inputs = [torch.randn(10, 6, dtype=torch.float, device=device)] + + self.checkScript(fn, inputs) + self.assertLastGraphAllFused() @staticmethod def _test_chunk_correctness(self, device='cpu'): @@ -303,99 +302,96 @@ def f(x, y): "ConstantChunk", 1, exactly=True ).run(str(graph)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_chunk_motion_deduplicates_inputs(self): - def func1(x): - z = x * x - z0, z1 = z.chunk(2) - return z0 * z1 + for device in self.devices: + def func1(x): + z = x * x + z0, z1 = z.chunk(2) + return z0 * z1 + + def func2(x): + z = x * x * x + z0, z1 = z.chunk(2) + return z0 * z1 + + inputs = [ + torch.tensor([1.1, 1.2], device=device, dtype=torch.float), + ] + for func in [func1, func2]: + self.checkScript(func, inputs) + self.assertLastGraphAllFused() - def func2(x): - z = x * x * x - z0, z1 = z.chunk(2) - return z0 * z1 + def test_chunk_multiple(self): + for device in self.devices: + # The arguments are intentionally used out of order as a test to see + # if the fusion compiler adds extra args in the correct order + def fn(s, x, y, z): + z1, z2 = z.chunk(2, 2) + x1, x2, x3 = x.chunk(3, 1) + y1, y2 = y.chunk(2, 0) + return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 + + inputs = [ + torch.randn(5, 2, 3, dtype=torch.float, device=device), + torch.randn(5, 6, 3, dtype=torch.float, device=device), + torch.randn(10, 2, 3, dtype=torch.float, device=device), + torch.randn(5, 2, 6, dtype=torch.float, device=device), + ] + + ge = self.checkScript(fn, inputs) + self.assertAllFused(ge.graph_for(*inputs)) - inputs = [ - torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float), - ] - for func in [func1, func2]: - self.checkScript(func, inputs) - self.assertLastGraphAllFused() + def test_minmax(self): + for device in self.devices: + def tmax(a, b): + return torch.max(2 * a, b) - @unittest.skipIf(not RUN_CUDA, "No CUDA") - def test_chunk_multiple_cuda(self): - # The arguments are intentionally used out of order as a test to see - # if the fusion compiler adds extra args in the correct order - def fn(s, x, y, z): - z1, z2 = z.chunk(2, 2) - x1, x2, x3 = x.chunk(3, 1) - y1, y2 = y.chunk(2, 0) - return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 + def tmin(a, b): + return torch.min(2 * a, b) - inputs = [ - torch.randn(5, 2, 3, dtype=torch.float, device='cuda'), - torch.randn(5, 6, 3, dtype=torch.float, device='cuda'), - torch.randn(10, 2, 3, dtype=torch.float, device='cuda'), - torch.randn(5, 2, 6, dtype=torch.float, device='cuda'), - ] + a = torch.randn(4, 4, dtype=torch.float) + b = torch.randn(4, 4, dtype=torch.float) + nan = torch.tensor(float('nan'), dtype=torch.float) - ge = self.checkScript(fn, inputs) - self.assertAllFused(ge.graph_for(*inputs)) + for f, inputs, device in product( + (tmax, tmin), + ([a, b], [a, nan], [b, nan]), + self.devices): + inputs = [t.to(device) for t in inputs] + s = self.checkScript(f, inputs) + self.assertAllFused(s.graph_for(*inputs)) - def test_minmax(self): - def tmax(a, b): - return torch.max(2 * a, b) - - def tmin(a, b): - return torch.min(2 * a, b) - - a = torch.randn(4, 4, dtype=torch.float) - b = torch.randn(4, 4, dtype=torch.float) - nan = torch.tensor(float('nan'), dtype=torch.float) - - devices = ["cpu"] - if torch.cuda.is_available(): - devices.append("cuda") - for f, inputs, device in product( - (tmax, tmin), - ([a, b], [a, nan], [b, nan]), - devices): - inputs = [t.to(device) for t in inputs] - s = self.checkScript(f, inputs) - self.assertAllFused(s.graph_for(*inputs)) - - # TODO: reenable the test after backwards passes start working in PE - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_clamp(self): - def func2(a, b): - return torch.clamp(a + b, min=0, max=2) + for device in self.devices: + def func2(a, b): + return torch.clamp(a + b, min=0, max=2) - def funcInf(a, b): - return torch.clamp(a + b, min=0, max=float('inf')) + def funcInf(a, b): + return torch.clamp(a + b, min=0, max=float('inf')) - def funcNegInf(a, b): - return torch.clamp(a + b, min=float('-inf'), max=0) + def funcNegInf(a, b): + return torch.clamp(a + b, min=float('-inf'), max=0) - def funcOptMin(a, b): - return torch.clamp(a + b, max=2) + def funcOptMin(a, b): + return torch.clamp(a + b, max=2) - def funcOptMax(a, b): - return torch.clamp(a + b, min=0) + def funcOptMax(a, b): + return torch.clamp(a + b, min=0) - a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) - b = torch.randn(4, 4, dtype=torch.float, device='cuda') - nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda') - - funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) - for f, inputs in product(funcs, [[a, b], [a, nan]]): - inp1, inp2 = inputs - s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) - self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) - c = s(inp1, inp2) - with enable_profiling_mode_for_profiling_tests(): - warmup_backward(c.sum()) - graph = backward_graph(s) - self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}) + a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) + b = torch.randn(4, 4, dtype=torch.float, device=device) + nan = torch.tensor(float('nan'), dtype=torch.float, device=device) + + funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) + for f, inputs in product(funcs, [[a, b], [a, nan]]): + inp1, inp2 = inputs + s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) + self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) + c = s(inp1, inp2) + with enable_profiling_mode_for_profiling_tests(): + warmup_backward(c.sum()) + graph = backward_graph(s) + self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") @@ -425,31 +421,31 @@ def f(x, y, z): ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_mul_bool(self): - def f(x, y, z): - return x * y * z + for device in self.devices: + def f(x, y, z): + return x * y * z - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') + x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_div_bool(self): - def f(x, y, z): - return (x + y) / z + for device in self.devices: + def f(x, y, z): + return (x + y) / z - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.ones_like(x, dtype=torch.bool, device='cuda') + x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + z = torch.ones_like(x, dtype=torch.bool, device=device) - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") + @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval") def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) @@ -467,7 +463,7 @@ def apply(fn): operator.__or__, operator.__xor__ ] - devices = ["cuda"] + devices = self.devices for dtype, op, device in product(dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) @@ -528,20 +524,20 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_comparison_eq_ne(self): - def f(x, y): - mask = (x == 0).type_as(x) - z = x * mask + y - mask = (x != 0).type_as(x) - z = z * mask + y - return z + for device in self.devices: + def f(x, y): + mask = (x == 0).type_as(x) + z = x * mask + y + mask = (x != 0).type_as(x) + z = z * mask + y + return z - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(f, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(f, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) @staticmethod def fn_test_comparison_gt_lt(x, y): @@ -551,47 +547,47 @@ def fn_test_comparison_gt_lt(x, y): z = z * mask + y return z - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_comparison_gt_lt_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') - - ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_comparison_ge_le_cuda(self): - def f(x, y): - mask = (x >= 0).type_as(x) - z = x * mask + y - mask = (x <= 0).type_as(x) - z = z * mask + y - return z - - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') - - ge = self.checkTrace(f, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) - x.requires_grad_(True) - y.requires_grad_(True) - self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_addcmul_cuda(self): - t = torch.randn(1, 4, dtype=torch.float, device='cuda') - t1 = torch.randn(4, 1, dtype=torch.float, device='cuda') - t2 = torch.randn(1, 4, dtype=torch.float, device='cuda') - - def foo(t, t1, t2): - return t.addcmul(t + 1, t2, value=0.1) - - ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) - graph = ge.graph_for(t, t1, t2) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) + def test_comparison_gt_lt(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) + + ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) + + def test_comparison_ge_le(self): + for device in self.devices: + def f(x, y): + mask = (x >= 0).type_as(x) + z = x * mask + y + mask = (x <= 0).type_as(x) + z = z * mask + y + return z + + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) + + ge = self.checkTrace(f, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) + x.requires_grad_(True) + y.requires_grad_(True) + self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) + + def test_addcmul(self): + for device in self.devices: + t = torch.randn(1, 4, dtype=torch.float, device=device) + t1 = torch.randn(4, 1, dtype=torch.float, device=device) + t2 = torch.randn(1, 4, dtype=torch.float, device=device) + + def foo(t, t1, t2): + return t.addcmul(t + 1, t2, value=0.1) + + ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) + graph = ge.graph_for(t, t1, t2) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) # TODO: We leak CUDA memory here because the traced graph holds onto a # constant-ified tensor. Since the Python-global CompilationUnit is alive diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bdebb355e33f..a6af57152b9a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -184,6 +184,7 @@ def _jit_can_fuse_on_cpu() -> _bool: ... def _jit_can_fuse_on_gpu() -> _bool: ... def _jit_texpr_fuser_enabled() -> _bool: ... def _jit_nvfuser_enabled() -> _bool: ... +def _llvm_enabled() -> _bool: ... def _jit_override_can_fuse_on_cpu(override: _bool): ... def _jit_override_can_fuse_on_gpu(override: _bool): ... def _jit_set_texpr_fuser_enabled(enable: _bool): ... diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 9cc2f07b2e6c..c2de6ec9292e 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -629,6 +629,13 @@ void initJITBindings(PyObject* module) { using namespace torch::jit::tensorexpr; getTEMustUseLLVMOnCPU() = use_llvm; }) + .def("_llvm_enabled", []() { + #ifdef TORCH_ENABLE_LLVM + return true; + #else + return false; + #endif + }) .def( "_jit_pass_fuse_tensorexprs", [](std::shared_ptr& g) { return FuseTensorExprs(g); }) From e618bd858e69587e593a723212ed68c366d477ae Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 64/93] [NNC] Fix llvm min lowering for int inputs (#47370) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47370 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805249 Pulled By: eellison fbshipit-source-id: e13d956899e8651600fab94dab04aa39ca427769 --- test/test_jit_fuser_te.py | 3 +-- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index bbefd1fc3ab8..051066d4d8ae 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -485,7 +485,6 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) @@ -502,7 +501,7 @@ def apply(fn): torch.min, torch.max ] - devices = ["cuda"] + devices = self.devices for dtype, op, device in product(dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 0cc7b1f65852..98c4df86ea4c 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -719,8 +719,7 @@ void LLVMCodeGenImpl::visit(const Min* v) { auto lhs = this->value_; v->rhs()->accept(this); auto rhs = this->value_; - - if (v->dtype() == kInt) { + if (v->dtype().is_integral()) { auto icmp = irb_.CreateICmpSLT(lhs, rhs); value_ = irb_.CreateSelect(icmp, lhs, rhs); return; From 450738441b6cb28c7f0f3ed65fe924e2399bd73c Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 65/93] [NNC] Add more CPU Tests (#47371) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47371 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805252 Pulled By: eellison fbshipit-source-id: 16472960d09f6c981adca2a45b2a4efb75a09d4f --- test/test_jit_fuser_te.py | 143 +++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 79 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 051066d4d8ae..233399e63cf4 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -594,42 +594,42 @@ def foo(t, t1, t2): # Removed `_cuda` suffix from this test which disables leak-checking. # If this is a real problem, we'll need to revisit Torchscript Function # lifetimes in Python. - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_lerp(self): - start = torch.randn(4, 1, dtype=torch.float, device='cuda') - end = torch.randn(1, 4, dtype=torch.float, device='cuda') - weight = torch.tensor(0.5, dtype=torch.float, device='cuda') + for device in self.devices: + start = torch.randn(4, 1, dtype=torch.float, device=device) + end = torch.randn(1, 4, dtype=torch.float, device=device) + weight = torch.tensor(0.5, dtype=torch.float, device=device) - # scalar weight overload - def foo_weight_scalar(start, end): - return torch.lerp(start + 1, end, 0.5) + # scalar weight overload + def foo_weight_scalar(start, end): + return torch.lerp(start + 1, end, 0.5) - # tensor weight overload - def foo_weight_tensor(start, end): - return torch.lerp(start + 1, end, weight) + # tensor weight overload + def foo_weight_tensor(start, end): + return torch.lerp(start + 1, end, weight) - ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) - graph = ge_weight_scalar.graph_for(start, end) - self.assertAllFused(graph) + ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) + graph = ge_weight_scalar.graph_for(start, end) + self.assertAllFused(graph) - # TODO: uncomment when TE enables support for scalar tensors - # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) - # graph = ge_weight_tensor.graph_for(start, end) - # self.assertAllFused(graph) + # TODO: uncomment when TE enables support for scalar tensors + # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) + # graph = ge_weight_tensor.graph_for(start, end) + # self.assertAllFused(graph) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_concat_cuda(self): - hx = torch.randn(3, 20, dtype=torch.float, device='cuda') - cx = torch.randn(3, 20, dtype=torch.float, device='cuda') + for device in self.devices: + hx = torch.randn(3, 20, dtype=torch.float, device=device) + cx = torch.randn(3, 20, dtype=torch.float, device=device) - def foo(hx, cx): - return torch.cat((hx + cx, hx * cx)) + def foo(hx, cx): + return torch.cat((hx + cx, hx * cx)) - ge = self.checkTrace(foo, (hx, cx)) - graph = ge.graph_for(hx, cx) - self.assertAllFused(graph) - # XXX: TE fuser can handle concats in a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + ge = self.checkTrace(foo, (hx, cx)) + graph = ge.graph_for(hx, cx) + self.assertAllFused(graph) + # XXX: TE fuser can handle concats in a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_remove_output_used_only_in_size(self): @@ -674,13 +674,13 @@ def fn(x, y, z): def fn_test_exp(x, y): return (x + .5 * y).exp() - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_exp_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_exp(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(self.fn_test_exp, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(self.fn_test_exp, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") @@ -738,14 +738,14 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph): test_norm_decompose(lm, ['aten::batch_norm_stats'], ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add']) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_threshold(self): - def f(x): - return torch.threshold(x, 0, -10) + x + x + x + for device in self.devices: + def f(x): + return torch.threshold(x, 0, -10) + x + x + x - x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda') - scripted = self.checkScript(f, (x,)) - self.assertAllFused(scripted.graph_for(x)) + x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device) + scripted = self.checkScript(f, (x,)) + self.assertAllFused(scripted.graph_for(x)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_scalar_arg_cuda(self): @@ -896,25 +896,10 @@ def doit(x, y): ge = self.checkTrace(doit, (x, y)) self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_cuda(self): - inputs = get_lstm_inputs('cuda', training=True) - module = self.checkScript(LSTMCellS, inputs) - return - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) - self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2) - # Everything is differentiable but TupleConstruct return - FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ - .check_next("return").run(str(forward_graph)) - - with enable_profiling_mode_for_profiling_tests(True): - hy, cy = module(*inputs) - warmup_backward((hy + cy).sum()) - backward = backward_graph(module) - self.assertAllFused(backward, except_for=("aten::t", "aten::mm", - "aten::_grad_sum_to_size")) + def test_lstm(self): + for device in self.devices: + inputs = get_lstm_inputs(device, training=True) + module = self.checkScript(LSTMCellS, inputs) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_lstm_concat_cuda(self): @@ -924,27 +909,27 @@ def test_lstm_concat_cuda(self): # XXX: TE fuser can handle concats inside a fusion group. # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_gates_permutations_cuda(self): - # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. - # Test that any permutation of this will still result in one FusionGroup. - choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] - template = dedent(''' - def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): - gates = {} + {} + {} + {} - ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - return ingate * forgetgate * cellgate * outgate - ''') - for permutation in permutations(choices, len(choices)): - code = template.format(*permutation) - scope = {} - exec(code, globals(), scope) - cu = torch.jit.CompilationUnit(code) - - inputs = get_lstm_inputs('cuda', training=False) - self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) - forward_graph = cu.cell.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) + def test_lstm_gates_permutations(self): + for device in self.devices: + # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. + # Test that any permutation of this will still result in one FusionGroup. + choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] + template = dedent(''' + def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): + gates = {} + {} + {} + {} + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + return ingate * forgetgate * cellgate * outgate + ''') + for permutation in permutations(choices, len(choices)): + code = template.format(*permutation) + scope = {} + exec(code, globals(), scope) + cu = torch.jit.CompilationUnit(code) + + inputs = get_lstm_inputs(device, training=False) + self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) + forward_graph = cu.cell.graph_for(*inputs) + self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) # TODO: Fuser doesn't work at all when inputs require grad. Fix that @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") From 346a71d29cf1832b9331a74c0f9e48b446a253a1 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 66/93] [NNC] More cpu tests (#47372) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47372 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805254 Pulled By: eellison fbshipit-source-id: b7e5ee044ef816e024b6fc5c4041fff5f2049bb3 --- test/test_jit_fuser_te.py | 303 ++++++++++----------- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 4 +- 2 files changed, 147 insertions(+), 160 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 233399e63cf4..1139c90664b4 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -284,23 +284,23 @@ def test_chunk_correctness(self): def test_chunk_correctness_cuda(self): return self._test_chunk_correctness(self, 'cuda') - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_chunk_distributes_cuda(self): - def f(x, y): - z1, z2 = (x + y).chunk(2, dim=1) - return z1 * z2 + def test_chunk_distributes(self): + for device in self.devices: + def f(x, y): + z1, z2 = (x + y).chunk(2, dim=1) + return z1 * z2 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(f, (x, y)) - graph = ge.graph_for(x, y) - # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. - # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ - # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) - FileCheck().check("with " + FUSION_GROUP + "_").check_count( - "ConstantChunk", 1, exactly=True - ).run(str(graph)) + ge = self.checkTrace(f, (x, y)) + graph = ge.graph_for(x, y) + # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. + # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ + # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) + FileCheck().check("with " + FUSION_GROUP + "_").check_count( + "ConstantChunk", 1, exactly=True + ).run(str(graph)) def test_chunk_motion_deduplicates_inputs(self): for device in self.devices: @@ -617,7 +617,7 @@ def foo_weight_tensor(start, end): # graph = ge_weight_tensor.graph_for(start, end) # self.assertAllFused(graph) - def test_concat_cuda(self): + def test_concat(self): for device in self.devices: hx = torch.randn(3, 20, dtype=torch.float, device=device) cx = torch.randn(3, 20, dtype=torch.float, device=device) @@ -651,24 +651,24 @@ def test_fuse(a, b): # the if node and the fusion group inside it should only have one output self.assertEqual(len(list(if_nodes[0].outputs())), 1) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_concat_invariant_cuda(self): - # Invariant: the output of prim::FusedConcat may - # not be an input to any node inside the FusionGroup. - def fn(x, y, z): - x1 = x + y - y1 = x - y - w = torch.cat([x1, y1]) - return w + z - - x = torch.randn(2, 2, dtype=torch.float, device='cuda') - y = torch.randn(2, 2, dtype=torch.float, device='cuda') - z = torch.randn(4, 2, dtype=torch.float, device='cuda') - ge = self.checkTrace(fn, (x, y, z)) - graph = ge.graph_for(x, y, z) - self.assertAllFused(graph, except_for={'aten::add'}) - # XXX: TE fuser can handle concats inside a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + def test_concat_invariant(self): + for device in self.devices: + # Invariant: the output of prim::FusedConcat may + # not be an input to any node inside the FusionGroup. + def fn(x, y, z): + x1 = x + y + y1 = x - y + w = torch.cat([x1, y1]) + return w + z + + x = torch.randn(2, 2, dtype=torch.float, device=device) + y = torch.randn(2, 2, dtype=torch.float, device=device) + z = torch.randn(4, 2, dtype=torch.float, device=device) + ge = self.checkTrace(fn, (x, y, z)) + graph = ge.graph_for(x, y, z) + self.assertAllFused(graph, except_for={'aten::add'}) + # XXX: TE fuser can handle concats inside a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @staticmethod def fn_test_exp(x, y): @@ -747,29 +747,29 @@ def f(x): scripted = self.checkScript(f, (x,)) self.assertAllFused(scripted.graph_for(x)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_scalar_arg_cuda(self): - def fn_test_scalar_arg(x, p): - # type: (Tensor, float) -> Tensor - return p * (x * x + x) + def test_scalar_arg(self): + for device in self.devices: + def fn_test_scalar_arg(x, p): + # type: (Tensor, float) -> Tensor + return p * (x * x + x) - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - p = 3 - scripted = self.checkScript(fn_test_scalar_arg, (x, p)) - self.assertAllFused(scripted.graph_for(x, p)) + x = torch.randn(4, 4, dtype=torch.float, device=device) + p = 3 + scripted = self.checkScript(fn_test_scalar_arg, (x, p)) + self.assertAllFused(scripted.graph_for(x, p)) - x.requires_grad_(True) + x.requires_grad_(True) - # use another function otherwise we will bailout - # and won't be able to do fused checks - def fn_test_scalar_arg_requires_grad(x, p): - # type: (Tensor, float) -> Tensor - return p * (x * x + x) + # use another function otherwise we will bailout + # and won't be able to do fused checks + def fn_test_scalar_arg_requires_grad(x, p): + # type: (Tensor, float) -> Tensor + return p * (x * x + x) - scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) - out = scripted(x, p) - self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) + out = scripted(x, p) + self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") def test_fuser_deduplication(self): @@ -900,14 +900,16 @@ def test_lstm(self): for device in self.devices: inputs = get_lstm_inputs(device, training=True) module = self.checkScript(LSTMCellS, inputs) + self.assertLastGraphAllFused() - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_concat_cuda(self): - inputs = get_lstm_inputs('cuda') - ge = self.checkTrace(LSTMCellC, inputs) - graph = ge.graph_for(*inputs) - # XXX: TE fuser can handle concats inside a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + def test_lstm_concat(self): + for device in self.devices: + inputs = get_lstm_inputs(device) + ge = self.checkTrace(LSTMCellC, inputs) + graph = ge.graph_for(*inputs) + self.assertLastGraphAllFused() + # XXX: TE fuser can handle concats inside a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) def test_lstm_gates_permutations(self): for device in self.devices: @@ -932,42 +934,26 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) # TODO: Fuser doesn't work at all when inputs require grad. Fix that - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_traced_cuda(self): - inputs = get_lstm_inputs('cuda') - ge = self.checkTrace(LSTMCellF, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0])) - - @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746") - def test_lstm_traced_cpu(self): - inputs = get_lstm_inputs('cpu') - try: + def test_lstm_traced(self): + for device in self.devices: + inputs = get_lstm_inputs(device) ge = self.checkTrace(LSTMCellF, inputs) graph = ge.graph_for(*inputs) - FileCheck.check("FusionGroup").run(str(graph)) - except RuntimeError as e: - if 'Failed to compile' in e.args[0]: - warnings.warn('CPU fuser test has failed! This is not a hard failure, ' - 'because the kernels sometimes trigger bugs in compilers ' - '(most notably GCC 7.2).') - raise unittest.SkipTest('Failed to compile') from e - else: - raise + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0])) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_milstm_cuda(self): - inputs = get_milstm_inputs('cuda', training=True) - module = self.checkScript(MiLSTMCell, inputs) - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) - FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ - .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) - hy, cy = module(*inputs) - warmup_backward((hy + cy).sum()) + def test_milstm(self): + for device in self.devices: + inputs = get_milstm_inputs(device, training=True) + module = self.checkScript(MiLSTMCell, inputs) + forward_graph = module.graph_for(*inputs) + self.assertGraphContainsExactly( + forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) + FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ + .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) + hy, cy = module(*inputs) + warmup_backward((hy + cy).sum()) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1000,26 +986,26 @@ def create(self, x): def fn_test_relu(x, y): return F.relu(x + .5 * y) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_relu_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_relu(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(self.fn_test_relu, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(self.fn_test_relu, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_erf_cuda(self): - def fn_test_erf(x): - return F.relu(torch.erf(x) - torch.erfc(x)) + def test_erf(self): + for device in self.devices: + def fn_test_erf(x): + return F.relu(torch.erf(x) - torch.erfc(x)) - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - ge = self.checkTrace(fn_test_erf, (x,)) - self.assertAllFused(ge.graph_for(x)) - x.requires_grad_(True) - ge = self.checkTrace(fn_test_erf, (x,)) - self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + x = torch.randn(4, 4, dtype=torch.float, device=device) + ge = self.checkTrace(fn_test_erf, (x,)) + self.assertAllFused(ge.graph_for(x)) + x.requires_grad_(True) + ge = self.checkTrace(fn_test_erf, (x,)) + self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1082,67 +1068,68 @@ def fn(x, y): ge = self.checkScript(fn, (x, y)) self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_small_constant_cuda(self): - def fn_test_small_constant(x, y): - return (1e-8 * x + 5e-9 * y) * 1e8 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_small_constant(self): + for device in self.devices: + def fn_test_small_constant(x, y): + return (1e-8 * x + 5e-9 * y) * 1e8 + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(fn_test_small_constant, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(fn_test_small_constant, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") # Currently we don't pull constants into fusion groups, because in some # cases it could remove the constant from the original graph and now our # fusion group needs to return that constant for its other users. # Instead of never pulling constants into the fusion group, we should just # be more careful at how we rewrite its users. # TODO: fix that and reenable the test. - def test_tensor_scalar_ops_cuda(self): - def should_fuse(x): - z = 3. - y = x + z - return x * y - - def should_fuse_scalar(x, z): - y = x + int(z) - return x * y - - inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')] - ge = self.checkScript(should_fuse, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) + def test_tensor_scalar_ops(self): + for device in self.devices: + def should_fuse(x): + z = 3. + y = x + z + return x * y - inputs = [ - torch.randn(2, 2, dtype=torch.float, device='cuda'), - torch.tensor(3., dtype=torch.float, device='cuda'), - ] - ge = self.checkScript(should_fuse_scalar, inputs) - # Check that the fused graph computes correct results when the scalar - # input changes. - inputs = [ - torch.randn(2, 2, dtype=torch.float, device='cuda'), - torch.tensor(7., dtype=torch.float, device='cuda'), - ] - self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) - # The TE fuser supports fusion of non-constant scalars - self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) + def should_fuse_scalar(x, z): + y = x + int(z) + return x * y + + inputs = [torch.randn(2, 2, dtype=torch.float, device=device)] + ge = self.checkScript(should_fuse, inputs) + graph = ge.graph_for(*inputs) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) + + inputs = [ + torch.randn(2, 2, dtype=torch.float, device=device), + torch.tensor(3., dtype=torch.float, device=device), + ] + ge = self.checkScript(should_fuse_scalar, inputs) + # Check that the fused graph computes correct results when the scalar + # input changes. + inputs = [ + torch.randn(2, 2, dtype=torch.float, device=device), + torch.tensor(7., dtype=torch.float, device=device), + ] + self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) + # The TE fuser supports fusion of non-constant scalars + self.assertGraphContainsExactly( + ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) def test_where_and_typing(self): - def f(x, y): - mask = x > y - res = torch.where(mask, x, y) - return mask, res + for device in self.devices: + def f(x, y): + mask = x > y + res = torch.where(mask, x, y) + return mask, res - x = torch.randn(4, 4, dtype=torch.double) - y = torch.randn(4, 4, dtype=torch.double) + x = torch.randn(4, 4, dtype=torch.double, device=device) + y = torch.randn(4, 4, dtype=torch.double, device=device) - script_f = self.checkScript(f, (x, y)) - self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) + script_f = self.checkScript(f, (x, y)) + self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 98c4df86ea4c..e98fdb890e83 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -401,12 +401,12 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { if (v->op_type() == kTanh) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_tanh(v->param(0)); + return fast_tanh(v->param(0)->accept_mutator(this)); } } else if (v->op_type() == kSigmoid) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_sigmoid(v->param(0)); + return fast_sigmoid(v->param(0)->accept_mutator(this)); } } // TODO: fast exp From dcca712d3c97840872ce25b728bd61a1cdb2d12b Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 67/93] [NNC] refactor cuda half support to more general file (#47373) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47373 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805246 Pulled By: eellison fbshipit-source-id: 33b5c84c9212d51bac3968e02aae2434dde40cd8 --- torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 7 ++++--- .../tensorexpr/{cuda_half_support.h => half_support.h} | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) rename torch/csrc/jit/tensorexpr/{cuda_half_support.h => half_support.h} (93%) diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index b23bc3f247ec..c674c0559bca 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -1,8 +1,9 @@ #include -#include +#include #include #include +#include #include #include #include @@ -922,7 +923,7 @@ void CudaCodeGen::Initialize() { // Check whether the statement uses the Half type, if so add the // half_support_literal. Stmt* stmt_v = stmt(); - CudaHalfChecker halfChecker; + HalfChecker halfChecker; stmt_v->accept(&halfChecker); if (halfChecker.hasHalf()) { os() << fuser::cuda::half_support_literal << std::endl; @@ -991,7 +992,7 @@ void CudaCodeGen::Initialize() { stmt_v = stmt_v->accept_mutator(&prioritize_load); // The registerizer might insert half-type scalars, we don't want this. - CudaHalfRewriter hsFix; + HalfRewriter hsFix; stmt_v = stmt_v->accept_mutator(&hsFix); stmt_v = IRSimplifier::simplify(stmt_v); diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/half_support.h similarity index 93% rename from torch/csrc/jit/tensorexpr/cuda_half_support.h rename to torch/csrc/jit/tensorexpr/half_support.h index 79514da7b4fc..c20d355af9d7 100644 --- a/torch/csrc/jit/tensorexpr/cuda_half_support.h +++ b/torch/csrc/jit/tensorexpr/half_support.h @@ -1,14 +1,15 @@ #pragma once -#include -#include +#include +#include +#include namespace torch { namespace jit { namespace tensorexpr { // Walk the Statment looking for Half size loads/stores. -class CudaHalfChecker : public IRVisitor { +class HalfChecker : public IRVisitor { public: bool hasHalf() { return hasHalf_; @@ -37,7 +38,7 @@ class CudaHalfChecker : public IRVisitor { bool hasHalf_{false}; }; -class CudaHalfRewriter : public IRMutator { +class HalfRewriter : public IRMutator { const Expr* mutate(const Load* v) override { const Expr* child = IRMutator::mutate(v); if (child->dtype().scalar_type() != ScalarType::Half) { From 664d2f48cf5383b5d9ee8334c1c3ea59b971d066 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 12 Nov 2020 11:06:50 -0800 Subject: [PATCH 68/93] [NNC] Enable unary op cpu testing (#47374) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47374 A few small fixes needed to enable unary op cpu testing. If reviewers would prefer I split them up let me know. Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24805248 Pulled By: eellison fbshipit-source-id: c2cfe2e3319a633e64da3366e68f5bf21d390cb7 --- test/test_jit_fuser_te.py | 98 ++++++++++++---------- torch/_C/__init__.pyi.in | 2 + torch/csrc/jit/tensorexpr/codegen.cpp | 6 +- torch/csrc/jit/tensorexpr/expr.cpp | 8 ++ torch/csrc/jit/tensorexpr/expr.h | 2 + torch/csrc/jit/tensorexpr/kernel.cpp | 2 +- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 65 +++++++++----- torch/csrc/jit/tensorexpr/llvm_jit.cpp | 6 ++ torch/testing/_internal/jit_utils.py | 8 ++ 9 files changed, 131 insertions(+), 66 deletions(-) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 1139c90664b4..d1a0e3e366d9 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -18,7 +18,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests from torch.testing._internal.jit_utils import JitTestCase, _inline_everything, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward + RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining from textwrap import dedent from itertools import product, permutations @@ -29,6 +29,7 @@ from torch.testing._internal.te_utils import CudaCodeGenExecuted FUSION_GROUP = 'prim::TensorExprGroup' +LLVM_ENABLED = torch._C._llvm_enabled() def strip_profiling_nodes(nodes): profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut']) @@ -63,6 +64,9 @@ def setUp(self): self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) self.old_profiling_mode = torch._C._jit_set_profiling_mode(True) + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) + self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) @@ -75,6 +79,7 @@ def tearDown(self): torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) @@ -219,21 +224,23 @@ def test_cuda_half(self): self.assertEqual(grads_half, fusion_grads) def test_checks_cat_inputs(self): - for device in self.devices: - # We shouldn't treat cat nodes as broadcasting. All their inputs - # need to be checked for having the same map size, before we can - # run the kernel. - def f(x, y): - return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) - - # NOTE: y is broadcastable to x, but output of f(x, y) should have - # shape 3x4, and not 4x4. - x = torch.randn(2, 4, dtype=torch.float, device=device) - y = torch.randn(1, 4, dtype=torch.float, device=device) - - scripted = self.checkScript(f, (x, y)) - self.assertEqual(scripted(x, y).shape, (3, 4)) - self.assertAllFused(scripted.graph_for(x, y)) + # single fusion node causes error + with set_fusion_group_inlining(True): + for device in self.devices: + # We shouldn't treat cat nodes as broadcasting. All their inputs + # need to be checked for having the same map size, before we can + # run the kernel. + def f(x, y): + return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) + + # NOTE: y is broadcastable to x, but output of f(x, y) should have + # shape 3x4, and not 4x4. + x = torch.randn(2, 4, dtype=torch.float, device=device) + y = torch.randn(1, 4, dtype=torch.float, device=device) + + scripted = self.checkScript(f, (x, y)) + self.assertEqual(scripted(x, y).shape, (3, 4)) + self.assertAllFused(scripted.graph_for(x, y)) def test_chunk(self): for device in self.devices: @@ -445,7 +452,7 @@ def f(x, y, z): ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not torch._C._llvm_enabled(), "TODO: bugs in ir eval") + @unittest.skipIf(not LLVM_ENABLED, "TODO: bugs in ir eval") def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) @@ -618,18 +625,20 @@ def foo_weight_tensor(start, end): # self.assertAllFused(graph) def test_concat(self): - for device in self.devices: - hx = torch.randn(3, 20, dtype=torch.float, device=device) - cx = torch.randn(3, 20, dtype=torch.float, device=device) + # disabling concat causes error with single concat node + with set_fusion_group_inlining(True): + for device in self.devices: + hx = torch.randn(3, 20, dtype=torch.float, device=device) + cx = torch.randn(3, 20, dtype=torch.float, device=device) - def foo(hx, cx): - return torch.cat((hx + cx, hx * cx)) + def foo(hx, cx): + return torch.cat((hx + cx, hx * cx)) - ge = self.checkTrace(foo, (hx, cx)) - graph = ge.graph_for(hx, cx) - self.assertAllFused(graph) - # XXX: TE fuser can handle concats in a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + ge = self.checkTrace(foo, (hx, cx)) + graph = ge.graph_for(hx, cx) + self.assertAllFused(graph) + # XXX: TE fuser can handle concats in a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_remove_output_used_only_in_size(self): @@ -903,13 +912,15 @@ def test_lstm(self): self.assertLastGraphAllFused() def test_lstm_concat(self): - for device in self.devices: - inputs = get_lstm_inputs(device) - ge = self.checkTrace(LSTMCellC, inputs) - graph = ge.graph_for(*inputs) - self.assertLastGraphAllFused() - # XXX: TE fuser can handle concats inside a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + # single fusion node causes error + with set_fusion_group_inlining(True): + for device in self.devices: + inputs = get_lstm_inputs(device) + ge = self.checkTrace(LSTMCellC, inputs) + graph = ge.graph_for(*inputs) + self.assertLastGraphAllFused() + # XXX: TE fuser can handle concats inside a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) def test_lstm_gates_permutations(self): for device in self.devices: @@ -1184,8 +1195,11 @@ def fn(a): torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) - def data_for(self, dtype, device="cuda"): - v = torch.arange(1, 3, dtype=torch.float, device=device) + def data_for(self, dtype, device="cuda", size=None): + if size is None: + v = torch.arange(1, 3, dtype=torch.float, device=device) + else: + v = torch.rand(*size, device=device) if dtype == torch.bool: return v > 2 elif dtype in [torch.qint8, torch.quint8, torch.qint32]: @@ -1193,10 +1207,10 @@ def data_for(self, dtype, device="cuda"): else: return v.to(dtype) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") + @unittest.skipIf(not LLVM_ENABLED, "TODO: bugs in ir eval") def test_unary_ops(self): def apply(fn): - return lambda x: fn(2 * x) + return lambda x: fn(x) dtypes = [ torch.int8, @@ -1239,10 +1253,10 @@ def apply(fn): torch.trunc, torch.frac, ] - devices = ["cuda"] - for dtype, op, device in product(dtypes, unary_ops, devices): + sizes = [(1,), (2,), (4, 4)] + for dtype, op, device, size in product(dtypes, unary_ops, self.devices, sizes): try: - x = self.data_for(dtype, device) + x = self.data_for(dtype, device, size=size) fn = apply(op) ref = fn(x) except Exception: @@ -1256,7 +1270,7 @@ def apply(fn): self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(dtype), op.__name__, device]) + " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index a6af57152b9a..b05d588972f9 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -182,6 +182,8 @@ def _jit_pass_inline(Graph) -> None: ... def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ... def _jit_can_fuse_on_cpu() -> _bool: ... def _jit_can_fuse_on_gpu() -> _bool: ... +def _debug_get_fusion_group_inlining() -> _bool: ... +def _debug_set_fusion_group_inlining(enable: _bool): ... def _jit_texpr_fuser_enabled() -> _bool: ... def _jit_nvfuser_enabled() -> _bool: ... def _llvm_enabled() -> _bool: ... diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index e08452172094..b8f16c50e05f 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -51,8 +51,10 @@ std::unique_ptr CreateCodeGen( const Expr* GenericIntrinsicsExpander::mutate(const Intrinsics* v) { if (v->op_type() == kSigmoid) { auto x = v->param(0)->accept_mutator(this); - auto one = ExprHandle(getImmediateByType(v->dtype(), 1.0)); - auto zero = ExprHandle(getImmediateByType(v->dtype(), 0.0)); + auto one = expr_to_vec( + ExprHandle(getImmediateByType(v->dtype(), 1.0)), v->dtype().lanes()); + auto zero = expr_to_vec( + ExprHandle(getImmediateByType(v->dtype(), 0.0)), v->dtype().lanes()); ExprHandle y = one / (one + exp(zero - ExprHandle(x))); return y.node(); } diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index acbe0879e896..30e8df53571b 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -219,6 +219,14 @@ ExprHandle Buf::make(const std::vector& dims, Dtype dtype) { return Buf::make("", dims, dtype); } +ExprHandle expr_to_vec(ExprHandle v, int lanes) { + if (lanes == 1) { + return v; + } else { + return Broadcast::make(v, lanes); + } +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 7c64403a10bd..9b8dd23db0b1 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -310,6 +310,8 @@ TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); +TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes); + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 3c591ffef3e1..bdd74cefc9a0 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -420,7 +420,7 @@ ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) { ExprHandle promoteIntegerToFloat(const ExprHandle& e) { auto scalarType = static_cast(e.dtype().scalar_type()); - if (!c10::isIntegralType(scalarType)) { + if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { return e; } auto defaultType = static_cast( diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index e98fdb890e83..6797dd6f2ce5 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -1,6 +1,7 @@ #ifdef TORCH_ENABLE_LLVM #include +#include #include #include @@ -23,6 +24,7 @@ #endif #include +#include #include #include #include @@ -310,6 +312,11 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( module_->setDataLayout(assertSuccess(JTMB.getDefaultDataLayoutForTarget())); module_->setTargetTriple(JTMB.getTargetTriple().str()); + // We support float16 ops by casting expr inputs to float32 + // and then casting the result back to float16 + HalfRewriter hsFix; + stmt = stmt->accept_mutator(&hsFix); + // Emit prototype and bind argument Vars to parameter indices. llvm::Type* retTy = dtypeToLLVM(dtype); std::vector params; @@ -424,25 +431,25 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { int lanes = dtype.lanes(); // TODO: use a dedicated bind-var to make sure v is not evalualted multiple // times. Clamp the input expression to [-9, 9] - ExprHandle plus_9 = to_vec(9.0f, lanes); - ExprHandle minus_9 = to_vec(-9.0f, lanes); + ExprHandle plus_9 = float_to_vec(9.0f, lanes); + ExprHandle minus_9 = float_to_vec(-9.0f, lanes); ExprHandle v1 = Min::make(v, plus_9, false); v1 = Max::make(v1, minus_9, false); // The coefficients for the numerator - ExprHandle alpha_1 = to_vec(4.89352455891786e-03f, lanes); - ExprHandle alpha_3 = to_vec(6.37261928875436e-04f, lanes); - ExprHandle alpha_5 = to_vec(1.48572235717979e-05f, lanes); - ExprHandle alpha_7 = to_vec(5.12229709037114e-08f, lanes); - ExprHandle alpha_9 = to_vec(-8.60467152213735e-11f, lanes); - ExprHandle alpha_11 = to_vec(2.00018790482477e-13f, lanes); - ExprHandle alpha_13 = to_vec(-2.76076847742355e-16f, lanes); + ExprHandle alpha_1 = float_to_vec(4.89352455891786e-03f, lanes); + ExprHandle alpha_3 = float_to_vec(6.37261928875436e-04f, lanes); + ExprHandle alpha_5 = float_to_vec(1.48572235717979e-05f, lanes); + ExprHandle alpha_7 = float_to_vec(5.12229709037114e-08f, lanes); + ExprHandle alpha_9 = float_to_vec(-8.60467152213735e-11f, lanes); + ExprHandle alpha_11 = float_to_vec(2.00018790482477e-13f, lanes); + ExprHandle alpha_13 = float_to_vec(-2.76076847742355e-16f, lanes); // The coeffecients for the denominator - ExprHandle beta_0 = to_vec(4.89352518554385e-03f, lanes); - ExprHandle beta_2 = to_vec(2.26843463243900e-03f, lanes); - ExprHandle beta_4 = to_vec(1.18534705686654e-04f, lanes); - ExprHandle beta_6 = to_vec(1.19825839466702e-06f, lanes); + ExprHandle beta_0 = float_to_vec(4.89352518554385e-03f, lanes); + ExprHandle beta_2 = float_to_vec(2.26843463243900e-03f, lanes); + ExprHandle beta_4 = float_to_vec(1.18534705686654e-04f, lanes); + ExprHandle beta_6 = float_to_vec(1.19825839466702e-06f, lanes); // numerator ExprHandle v2 = v1 * v1; @@ -467,20 +474,16 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { // sigmoid(x) = (tanh(x / 2) + 1) / 2 ExprHandle x{v_ptr}; int lanes = x.dtype().lanes(); - ExprHandle one_v = to_vec(1.f, lanes); - ExprHandle half_v = to_vec(0.5f, lanes); + ExprHandle one_v = float_to_vec(1.f, lanes); + ExprHandle half_v = float_to_vec(0.5f, lanes); ExprHandle x2 = x * half_v; ExprHandle y{fast_tanh(x2.node())}; ExprHandle z = (y + one_v) * half_v; return z.node(); } - ExprHandle to_vec(float v, int lanes) { - if (lanes == 1) { - return v; - } else { - return Broadcast::make(v, lanes); - } + ExprHandle float_to_vec(float v, int lanes) { + return expr_to_vec(FloatImm::make(v), lanes); } }; @@ -1627,6 +1630,26 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { throw unimplemented_lowering(v); } break; } + } else if (v->dtype().is_integral() && v->op_type() == kFabs) { + // abs is only intrinsic defined for integer inputs in pytorch eager + v->params().front()->accept(this); + if (is_unsigned_integral(v->dtype().scalar_type())) { + return; + } + // TODO: use llvm.abs intrinsic for LLVM 12 + auto zero = llvm::ConstantInt::get(value_->getType(), 0); + auto neg_value = irb_.CreateSub(zero, value_); + auto icmp = irb_.CreateICmpSGT(value_, zero); + value_ = irb_.CreateSelect(icmp, value_, neg_value); + return; + } else { + TORCH_INTERNAL_ASSERT( + false, + v, + "Unimplemented lowering:", + v->op_type(), + " for input of dtype", + v->dtype().scalar_dtype()); } std::vector params; diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index f2161edde56a..fc487ee9dfcf 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -18,6 +18,8 @@ #include #include +#include + #include #include #include @@ -53,6 +55,10 @@ static void registerIntrinsics( entry("atan2f", &atan2f), entry("fmodf", &fmodf), entry("remainderf", &remainderf), + // float -> half & half -> float conversions + entry("__gnu_h2f_ieee", &c10::detail::fp16_ieee_to_fp32_value), + entry("__gnu_f2h_ieee", &c10::detail::fp16_ieee_from_fp32_value), + // FP32 Sleef functions -- SSE entry("Sleef_acosf4", &Sleef_acosf4_u10), entry("Sleef_asinf4", &Sleef_asinf4_u10), diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index c24c2b15c553..289b73f43118 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -647,6 +647,14 @@ def inline_everything_mode(should_inline): finally: torch._C._jit_set_inline_everything_mode(old) +@contextmanager +def set_fusion_group_inlining(inlining): + old = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(inlining) + try: + yield + finally: + torch._C._debug_set_fusion_group_inlining(old) # note: not re-entrant, use unnested only @contextmanager From 76ff557de76d31b5e5908cf4a0ce3abb8e34f867 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Thu, 12 Nov 2020 11:32:23 -0800 Subject: [PATCH 69/93] [NNC] add hazard analysis to Bounds Inference (#47684) Summary: Adds a helper function to Bounds Inference / Memory Analaysis infrastructure which returns the kind of hazard found between two Stmts (e.g. Blocks or Loops). E.g. ``` for (int i = 0; i < 10; ++i) { A[x] = i * 2; } for (int j = 0; j < 10; ++j) { B[x] = A[x] / 2; } ``` The two loops have a `ReadAfterWrite` hazard, while in this example: ``` for (int i = 0; i < 10; ++i) { A[x] = i * 2; } for (int j = 0; j < 10; ++j) { A[x] = B[x] / 2; } ``` The loops have a `WriteAfterWrite` hazard. This isn't 100% of what we need for loop fusion, for example we don't check the strides of the loop to see if they match. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47684 Reviewed By: malfet Differential Revision: D24873587 Pulled By: nickgg fbshipit-source-id: 991149e5942e769612298ada855687469a219d62 --- test/cpp/tensorexpr/test_boundsinference.cpp | 131 ++++++++++++++++++ test/cpp/tensorexpr/test_memdependency.cpp | 2 +- test/cpp/tensorexpr/tests.h | 4 + .../csrc/jit/tensorexpr/bounds_inference.cpp | 92 +++++++++++- torch/csrc/jit/tensorexpr/bounds_inference.h | 17 +++ torch/csrc/jit/tensorexpr/bounds_overlap.h | 10 ++ .../jit/tensorexpr/mem_dependency_checker.cpp | 48 ++++++- .../jit/tensorexpr/mem_dependency_checker.h | 21 ++- 8 files changed, 316 insertions(+), 9 deletions(-) diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index 4325f0af9a6f..c1943ae7099b 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -593,5 +593,136 @@ void testBoundsInferenceFlattened() { ASSERT_TRUE(exprEquals(TABI.stop[0], new IntImm(3 * 4 * 5 - 1))); } +void testGetPotentialHazards() { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* + * A[0] = B[0]; + * B[0] = 3; WAR on B + * A[0] = B[0]; WAW on A, RAW on B + * C[0] = 5; + */ + + Store* store1 = Store::make(a, {0}, Load::make(b, {0}, 1), 1); + Store* store2 = Store::make(b, {0}, 3, 1); + Store* store3 = Store::make(a, {0}, Load::make(b, {0}, 1), 1); + Store* store4 = Store::make(c, {0}, 5, 1); + Stmt* stmt = Block::make({store1, store2, store3, store4}); + + MemDependencyChecker analyzer; + stmt->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterRead, + getPotentialHazards(analyzer, store1, store2)); + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, store2, store3)); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, + getPotentialHazards(analyzer, store1, store3)); + + // Fourth store has no dependencies + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store1, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store2, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store3, store4)); + } +} + +void testGetPotentialHazardsLoopNoHazard() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return (i + 1) * (j + 1); + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + For* loopRootA = l.getLoopStmtsFor(A)[0]; + For* loopRootB = l.getLoopStmtsFor(B)[0]; + + // No dependencies between loops. + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +void testGetPotentialHazardsLoopCall() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{64, "i"}, {64, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i, j) + 5; + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + For* loopRootA = l.getLoopStmtsFor(A)[0]; + For* loopRootB = l.getLoopStmtsFor(B)[0]; + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +void testGetPotentialHazardsLoopSplit() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + + LoopNest l({A}); + For *outer, *inner, *tail; + + // Splitting with tail by something offset creates a tail which also writes to + // A. + l.splitWithTail(l.getLoopStmtsFor(A)[0], 5, &outer, &inner, &tail); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 3a0e70d98dc6..7b866747ca05 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -1112,7 +1112,7 @@ void testMemDependencyCheckerLoopSelfDependency() { // This check assumes that the Stmt has a single Store with a single Load on // the RHS. auto isSelfDependent = - [](const std::deque>& history) -> bool { + [](const std::vector>& history) -> bool { return history.front()->hasDependency(history.back()); }; diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index e3dd37235e05..ae58a31a538e 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -312,6 +312,10 @@ namespace jit { _(BoundsInferenceMultipleTopLoopStore) \ _(BoundsInferenceCacheReads) \ _(BoundsInferenceFlattened) \ + _(GetPotentialHazards) \ + _(GetPotentialHazardsLoopNoHazard) \ + _(GetPotentialHazardsLoopCall) \ + _(GetPotentialHazardsLoopSplit) \ _(BoundOverlap) \ _(BoundOverlapSymbolic) \ _(BoundOverlapMultiDim) \ diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index 543643a3c7ec..2424c2dfc45a 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -12,8 +12,9 @@ namespace tensorexpr { using namespace analysis; +template BoundsInfo mergeTensorAccesses( - const std::deque>& accesses, + const Container& accesses, const std::unordered_map& varToBuf, bool distinctAccessKinds) { BoundsInfo ret; @@ -90,6 +91,14 @@ BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds) { checker.getHistory(), varToBuf, distinctAccessKinds); } +BoundsInfo getInferredBounds( + MemDependencyChecker& analyzer, + Stmt* s, + bool distinctAccessKinds) { + return mergeTensorAccesses( + analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds); +} + void printBoundsInfo(const BoundsInfo& v) { std::cerr << "Access vector {\n"; for (auto& pair : v) { @@ -166,6 +175,87 @@ std::vector getBoundExtents( return extents; } +using BoundSet = std::unordered_set; + +BoundSet convertBounds( + const std::vector& bounds, + TensorAccessKind filter = kMutate) { + BoundSet ret; + for (auto& TABI : bounds) { + if (filter == kMutate || TABI.kind == filter) { + for (size_t i = 0; i < TABI.start.size(); ++i) { + ret.insert(Bound(TABI.start[i], TABI.stop[i])); + } + } + } + return ret; +} + +BoundSet convertBounds( + BoundsInfo& bounds, + const Buf* buf, + TensorAccessKind filter = kMutate) { + auto it = bounds.find(buf); + if (it == bounds.end()) { + return BoundSet(); + } + + return convertBounds(it->second, filter); +} + +HazardKind getPotentialHazards( + MemDependencyChecker& analyzer, + Stmt* A, + Stmt* B) { + BoundsInfo aBounds = getInferredBounds(analyzer, A, true); + BoundsInfo bBounds = getInferredBounds(analyzer, B, true); + + BoundSet aWrites; + BoundSet aReads; + + for (auto& pair : bBounds) { + const Buf* buf = pair.first; + if (aBounds.find(buf) == aBounds.end()) { + continue; + } + + auto aWrites = convertBounds(aBounds, buf, kStore); + auto aReads = convertBounds(aBounds, buf, kLoad); + + auto bWrites = convertBounds(pair.second, kStore); + auto bReads = convertBounds(pair.second, kLoad); + + // First, RAW. + for (auto& bR : bReads) { + for (auto& aW : aWrites) { + if (boundOverlap(bR, aW) != NoOverlap) { + return HazardKind::ReadAfterWrite; + } + } + } + + // Then WAR. + for (auto& bW : bWrites) { + for (auto& aR : aReads) { + if (boundOverlap(bW, aR) != NoOverlap) { + return HazardKind::WriteAfterRead; + } + } + } + + // Then WAW. + for (auto& bW : bWrites) { + for (auto& aW : aWrites) { + if (boundOverlap(bW, aW) != NoOverlap) { + return HazardKind::WriteAfterWrite; + } + } + } + } + + return HazardKind::NoDependency; +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.h b/torch/csrc/jit/tensorexpr/bounds_inference.h index 72cc6dc044e1..b5b58d09f0b6 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.h +++ b/torch/csrc/jit/tensorexpr/bounds_inference.h @@ -28,11 +28,28 @@ using BoundsInfo = TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true); +// Bounds inference caching the analysis. The MemDependencyChecker must already +// have been run. +TORCH_API BoundsInfo getInferredBounds( + analysis::MemDependencyChecker& analyzer, + Stmt* s, + bool distinctAccessKinds = true); + TORCH_API void printBoundsInfo(const BoundsInfo& v); TORCH_API std::vector getBoundExtents( const std::vector& infos); +// The kind of dependency found, in increasing order of exclusivity. +enum class HazardKind { + ReadAfterWrite, + WriteAfterRead, + WriteAfterWrite, + NoDependency, +}; +TORCH_API HazardKind +getPotentialHazards(analysis::MemDependencyChecker& analyzer, Stmt* A, Stmt* B); + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.h b/torch/csrc/jit/tensorexpr/bounds_overlap.h index 0b3ef3faa3ef..1ae4278c4b2c 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.h +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.h @@ -32,12 +32,22 @@ struct TORCH_API Bound { return exprEquals(start, other.start) && exprEquals(end, other.end); } + bool operator==(const Bound& other) const { + return exprEquals(start, other.start) && exprEquals(end, other.end); + } + void swap() { std::swap(start, end); swapped = !swapped; } }; +struct BoundHash { + size_t operator()(const Bound& b) const { + return std::hash()(b.start) ^ std::hash()(b.end); + } +}; + // The type of overlap found. Each condition is true only if none of the // previous conditions hold. // ContainedOrEqual: All elements in the Bound A are in the Bound B (this diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp index 43a15c2d4a5e..49faa865c612 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp @@ -276,7 +276,7 @@ bool MemDependencyChecker::allowLoopExecutionOrderAnalysis(bool allow) { return allow; } -const std::deque>& MemDependencyChecker:: +const std::vector>& MemDependencyChecker:: getHistory() const { return currentScope_->accesses_; } @@ -452,6 +452,27 @@ std::shared_ptr MemDependencyChecker::accessFor( return nullptr; } +std::unordered_set> MemDependencyChecker:: + accessesWithin(const Stmt* A) const { + auto it = scopeToAccesses_.find(A); + if (it != scopeToAccesses_.end()) { + return std::unordered_set>( + it->second.begin(), it->second.end()); + } + + std::unordered_set> ret; + auto bound = stmtToAccess_.equal_range(A); + for (auto it = bound.first; it != bound.second; ++it) { + ret.insert(it->second); + } + return ret; +} + +std::unordered_set> MemDependencyChecker:: + accessesWithin(const Expr* A) const { + return {accessFor(A)}; +} + std::shared_ptr MemDependencyChecker::input(const Buf* b) const { auto it = inputs_.find(b); if (it == inputs_.end()) { @@ -924,6 +945,19 @@ void MemDependencyChecker::visit(const For* v) { } } + std::vector> mergedAccesses; + mergedAccesses.reserve( + extentsScope->accesses_.size() + currentScope_->accesses_.size()); + std::copy( + extentsScope->accesses_.begin(), + extentsScope->accesses_.end(), + std::back_inserter(mergedAccesses)); + std::copy( + currentScope_->accesses_.begin(), + currentScope_->accesses_.end(), + std::back_inserter(mergedAccesses)); + scopeToAccesses_.emplace(v, mergedAccesses); + // it's a little faster to merge without closing, and since no writes can // occur within the start and stop exprs we'll do that. mergeScope(extentsScope, extentsScope->parent, false); @@ -935,13 +969,15 @@ void MemDependencyChecker::visit(const Cond* v) { const Stmt* last = lastStmt_; lastStmt_ = v; + auto enclosingScope = + std::make_shared(currentScope_->block, currentScope_); + // condition is in enclosing scope. v->condition()->accept(this); Block* true_stmt = v->true_stmt(); Block* false_stmt = v->false_stmt(); - auto enclosingScope = currentScope_; // Create scopes so the Block visitor doesn't create and merge a new scope. auto trueScope = std::make_shared(true_stmt, enclosingScope); auto falseScope = std::make_shared(false_stmt, enclosingScope); @@ -968,7 +1004,13 @@ void MemDependencyChecker::visit(const Cond* v) { mergeScope(trueScope, enclosingScope, false); mergeScope(falseScope, enclosingScope, false); + // Merge the enclosing scope into it's parent. + mergeScope(enclosingScope, enclosingScope->parent, false); + currentScope_ = enclosingScope; + scopeToAccesses_.emplace(v, enclosingScope->accesses_); + + currentScope_ = enclosingScope->parent; lastStmt_ = last; } @@ -1091,6 +1133,8 @@ void MemDependencyChecker::visit(const Block* v) { knownVarBounds_[pair.first] = pair.second; } + scopeToAccesses_.emplace(v, currentScope_->accesses_); + if (currentScope_ != prev_scope) { mergeScope(currentScope_, prev_scope, true); currentScope_ = prev_scope; diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index b799b185e1d9..a1bb91fa17ad 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -1,7 +1,6 @@ #pragma once #include #include -#include #include #include @@ -221,16 +220,25 @@ class TORCH_API MemDependencyChecker : public IRVisitor { const std::shared_ptr& A, const std::shared_ptr& B); - // Retuns the AccessInfo + // Returns the AccessInfo std::shared_ptr accessFor(const Stmt* A) const; std::shared_ptr accessFor(const Expr* A) const; + // Returns all AccessInfos. + std::unordered_set> accessesWithin( + const Stmt* A) const; + // TODO: this will return only the AccessInfo for A. It's included for + // completeness but be aware it wont return accesses used in the computation + // of A. + std::unordered_set> accessesWithin( + const Expr* A) const; + // Accesses relating to input and output buffers. std::shared_ptr input(const Buf* B) const; std::shared_ptr output(const Buf* B) const; // Returns the full history of reads and writes. - const std::deque>& getHistory() const; + const std::vector>& getHistory() const; // Dumps the dependency graph in DOT format. void dumpDAG(const std::string& filename) const; @@ -273,7 +281,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { std::unordered_map shadowedVarBounds; std::unordered_set localVars; - std::deque> accesses_; + std::vector> accesses_; std::unordered_map> openWrites_; }; @@ -285,6 +293,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor { stmtToAccess_; std::unordered_multimap> exprToAccess_; + std::unordered_map>> + scopeToAccesses_; VarBoundMap knownVarBounds_; @@ -303,7 +313,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor { } }; - // Look for and insert accesses belonging to all nodes that act like reads. + // Look for and insert accesses belonging to all nodes that act like + // reads. insertAllReads(NodeFinder::find(v)); insertAllReads(NodeFinder::find(v)); insertAllReads(NodeFinder::find(v)); From f51be328ae06e39d7a760194a6e646eef050e92c Mon Sep 17 00:00:00 2001 From: James Reed Date: Thu, 12 Nov 2020 11:32:33 -0800 Subject: [PATCH 70/93] [FX] Fix __tensor_constants not scriptable (#47817) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47817 Test Plan: Imported from OSS Reviewed By: nikithamalgifb Differential Revision: D24908959 Pulled By: jamesr66a fbshipit-source-id: c0cadae2091e917b72684262b8655f8813ac9d91 --- test/test_fx.py | 12 ++++++++++++ torch/fx/symbolic_trace.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_fx.py b/test/test_fx.py index b035b37663d4..b207f11b1d80 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -631,6 +631,18 @@ def forward(self, x : dict): with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): symbolic_trace(ud) + def test_script_tensor_constant(self): + # TorchScript seems to ignore attributes that start with `__`. + # We used to call anonymous Tensor values `__tensor_constant*`, but + # they were getting ignored by script. Now they're called + # `_tensor_constant*` + class IHaveATensorConstant(torch.nn.Module): + def forward(self, x): + return x + torch.rand(3, 4) + + traced = torch.fx.symbolic_trace(IHaveATensorConstant()) + torch.jit.script(traced) + def test_torch_custom_ops(self): class M(torch.nn.Module): def forward(self, a): diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 13ffc2cb0100..b2e5b0961114 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -94,7 +94,7 @@ def create_arg(self, a: Any) -> Argument: if not qualname: i = 0 while True: - qualname = f'__tensor_constant{i}' + qualname = f'_tensor_constant{i}' if not hasattr(self.root, qualname): break i += 1 From 6aaf04616b99276f72b48bd3cdf83be3a4b3296b Mon Sep 17 00:00:00 2001 From: Tao Xu Date: Thu, 12 Nov 2020 11:52:49 -0800 Subject: [PATCH 71/93] [Metal] Remove undefined tests Summary: As title Test Plan: - Circle CI - Sandcastle Reviewed By: husthyc Differential Revision: D24915370 fbshipit-source-id: fe05ac37a25c804695a13fb5a7eabbc60442a102 --- aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h | 1 - 1 file changed, 1 deletion(-) diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h index aa387f44d765..105f013da8e0 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h @@ -26,7 +26,6 @@ bool test_upsampling_nearest2d_vec(); bool test_adaptive_avg_pool2d(); bool test_hardtanh_(); bool test_reshape(); -bool test_mobilenetv2(); } // namespace metal } // namespace native From 275a89a7ee45772b5d353ec488ec2cd0c78a69c3 Mon Sep 17 00:00:00 2001 From: Omkar Salpekar Date: Thu, 12 Nov 2020 12:12:00 -0800 Subject: [PATCH 72/93] [Docs] Store Docs fixes about HashStore API (#47643) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47643 Updating the docs to indicate the `num_keys` and `delete_key` APIs are now supported by the HashStore (not just TCPStore). ghstack-source-id: 116459958 Test Plan: CI Reviewed By: jiayisuse, mrshenli Differential Revision: D24633570 fbshipit-source-id: 549479dd99f9ec6decbfffcb74b9792403d05ba2 --- torch/csrc/distributed/c10d/init.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index f17802c6f0b1..13cc832d28bd 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -440,8 +440,8 @@ Deletes the key-value pair associated with ``key`` from the store. Returns `true` if the key was successfully deleted, and `false` if it was not. .. warning:: - The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore`. Using this API - with the :class:`~torch.distributed.FileStore` or :class:`~torch.distributed.HashStore` will result in an exception. + The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API + with the :class:`~torch.distributed.FileStore` will result in an exception. Arguments: key (str): The key to be deleted from the store @@ -451,6 +451,7 @@ Deletes the key-value pair associated with ``key`` from the store. Returns Example:: >>> import torch.distributed as dist + >>> # Using TCPStore as an example, HashStore can also be used >>> store = dist.TCPStore("127.0.0.1", 0, true, timedelta(seconds=30)) >>> store.set("first_key") >>> # This should return true @@ -469,14 +470,15 @@ and :meth:`~torch.distributed.store.add` since one key is used to coordinate all the workers using the store. .. warning:: - The ``num_keys`` API is only supported by the :class:`~torch.distributed.TCPStore`. Using this API - with the :class:`~torch.distributed.FileStore` or :class:`~torch.distributed.HashStore` will result in an exception. + The ``num_keys`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API + with the :class:`~torch.distributed.FileStore` will result in an exception. Returns: The number of keys present in the store. Example:: >>> import torch.distributed as dist + >>> # Using TCPStore as an example, HashStore can also be used >>> store = dist.TCPStore("127.0.0.1", 0, true, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # This should return 2 From 149190c01453b33f1150556b9c5d2df3aad287d8 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 12 Nov 2020 12:12:24 -0800 Subject: [PATCH 73/93] Added CUDA support for complex input for torch.solve (#47045) Summary: `torch.solve` now works for complex inputs on GPU. I moved the existing tests to `test_linalg.py` and modified them to test complex and float32 dtypes. Differentiation also works correctly with complex inputs. Fixes https://github.com/pytorch/pytorch/issues/41084 Ref. https://github.com/pytorch/pytorch/issues/33152 anjali411 I hope you don't mind that I took over https://github.com/pytorch/pytorch/pull/42737 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47045 Reviewed By: nikithamalgifb Differential Revision: D24921503 Pulled By: anjali411 fbshipit-source-id: 4c3fc4f193a84b6e28c43c08672d480715000923 --- .../ATen/native/cuda/BatchLinearAlgebra.cu | 46 +++++++- test/test_autograd.py | 6 +- test/test_linalg.py | 105 +++++++++++++++--- test/test_torch.py | 81 -------------- tools/autograd/gen_variable_type.py | 2 +- torch/_torch_docs.py | 2 + torch/csrc/autograd/FunctionsManual.cpp | 6 +- torch/linalg/__init__.py | 2 +- .../_internal/common_methods_invocations.py | 20 +++- 9 files changed, 157 insertions(+), 113 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 5379d38fa43f..4f9ff63d0ece 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -160,6 +160,28 @@ void magmaSolve( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgesv_gpu(n, nrhs, + reinterpret_cast(dA), ldda, ipiv, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgesv_gpu(n, nrhs, + reinterpret_cast(dA), ldda, ipiv, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSolveBatched( magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, @@ -178,6 +200,28 @@ void magmaSolveBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** dipiv_array, c10::complex** dB_array, magma_int_t lddb, + magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { + magma_zgesv_batched(n, nrhs, + reinterpret_cast(dA_array), ldda, dipiv_array, + reinterpret_cast(dB_array), lddb, dinfo_array, batch_count, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** dipiv_array, c10::complex** dB_array, magma_int_t lddb, + magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { + magma_cgesv_batched(n, nrhs, + reinterpret_cast(dA_array), ldda, dipiv_array, + reinterpret_cast(dB_array), lddb, dinfo_array, batch_count, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLu( magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, @@ -1059,7 +1103,7 @@ std::tuple _solve_helper_cuda(const Tensor& self, const Tensor& auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "solve_cuda", [&]{ apply_solve(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { diff --git a/test/test_autograd.py b/test/test_autograd.py index 34ed8b72867e..d7948b56f27d 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4987,10 +4987,12 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', - 'exp', 'mean', 'inverse', 'triangular_solve'] + separate_complex_tests + 'exp', 'mean', 'inverse', 'triangular_solve', 'solve'] + separate_complex_tests # this list corresponds to cases that are not currently implemented -skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex', 'inverse_batched_complex'] +skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex', 'inverse_batched_complex', + 'solve_batched_broadcast_A_complex', 'solve_batched_broadcast_b_complex', + 'solve_batched_complex', 'solve_batched_dims_complex'] def add_test( name, diff --git a/test/test_linalg.py b/test/test_linalg.py index 67bf450a006c..16cfea8f536f 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8,8 +8,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, dtypesIfCUDA, - onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + (instantiate_device_type_tests, dtypes, + onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck, gradgradcheck @@ -1134,10 +1134,94 @@ def run_test_singular_input(batch_dim, n): for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: run_test_singular_input(*params) + def solve_test_helper(self, A_dims, b_dims, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device) + return b, A + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve(self, device, dtype): + for (k, n) in zip([2, 3, 5], [3, 5, 7]): + b, A = self.solve_test_helper((n,), (n, k), device, dtype) + x = torch.solve(b, A)[0] + self.assertEqual(b, A.mm(x)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve_batched(self, device, dtype): + def solve_batch_helper(A_dims, b_dims): + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.solve(b[i], A[i])[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.solve(b, A)[0] # Actual output + self.assertEqual(x_exp, x_act) # Equality check + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + Ax = torch.matmul(A.cpu(), x_act.cpu()).to(device) + else: + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) + + for batchsize in [1, 3, 4]: + solve_batch_helper((5, batchsize), (batchsize, 5, 10)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve_batched_non_contiguous(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype).to(device).permute(1, 0, 2) + b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) + x, _ = torch.solve(b, A) + x_exp = solve(A.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(x, x_exp) + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve_batched_many_batches(self, device, dtype): + for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]): + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) + x, _ = torch.solve(b, A) + # TODO(@ivanyashchuk): remove this once batched matmul is avaiable on CUDA for complex dtypes + if self.device_type == 'cuda' and dtype.is_complex: + Ax = torch.matmul(A.cpu(), x.cpu()).to(device) + else: + Ax = torch.matmul(A, x) + self.assertEqual(Ax, b.expand_as(x)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve_batched_broadcasting(self, device, dtype): + from numpy.linalg import solve + + def run_test(A_dims, b_dims): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype) + x, _ = torch.solve(b, A) + x_exp = solve(A.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(x, x_exp) + + # test against numpy.linalg.solve + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypesIfCUDA(torch.float, torch.double) @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) def test_tensorsolve(self, device, dtype): def run_test(a_shape, dims): @@ -1161,7 +1245,6 @@ def run_test(a_shape, dims): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypesIfCUDA(torch.float, torch.double) def test_tensorsolve_empty(self, device, dtype): # Check for empty inputs. NumPy does not work for these cases. a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) @@ -1169,23 +1252,9 @@ def test_tensorsolve_empty(self, device, dtype): x = torch.linalg.tensorsolve(a, b) self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b) - # TODO: once "solve_cuda" supports complex dtypes, they shall be added to above tests - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.cfloat, torch.cdouble) - def test_tensorsolve_xfailed(self, device, dtype): - a_shape = (2, 3, 6) - a = torch.randn(a_shape, dtype=dtype, device=device) - b = torch.randn(a_shape[:2], dtype=dtype, device=device) - result = torch.linalg.tensorsolve(a, b) - expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy()) - self.assertEqual(result, expected) - @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypesIfCUDA(torch.float, torch.double) @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) def test_tensorsolve_non_contiguous(self, device, dtype): def run_test_permuted(a_shape, dims): diff --git a/test/test_torch.py b/test/test_torch.py index 1067c56375eb..6bcbd5582dc8 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7613,87 +7613,6 @@ def run_test(matsize, batchdims, mat_chars): run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) - def solve_test_helper(self, A_dims, b_dims, device, dtype): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device) - return b, A - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve(self, device, dtype): - for (k, n) in zip([2, 3, 5], [3, 5, 7]): - b, A = self.solve_test_helper((n,), (n, k), device, dtype) - x = torch.solve(b, A)[0] - self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve_batched(self, device, dtype): - def solve_batch_helper(A_dims, b_dims): - b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.solve(b[i], A[i])[0]) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.solve(b, A)[0] # Actual output - self.assertEqual(x_exp, x_act) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - - for batchsize in [1, 3, 4]: - solve_batch_helper((5, batchsize), (batchsize, 5, 10)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_solve_batched_non_contiguous(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, - device=device).permute(1, 0, 2) - b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) - self.assertEqual(x, x_exp) - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve_batched_many_batches(self, device, dtype): - b, A = self.solve_test_helper((5, 256, 256), (5, 1), device, dtype) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - - b, A = self.solve_test_helper((3,), (512, 512, 3, 1), device, dtype) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_solve_batched_broadcasting(self, device, dtype): - from numpy.linalg import solve - - def run_test(A_dims, b_dims): - A_matrix_size = A_dims[-1] - A_batch_dims = A_dims[:-2] - b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) - self.assertEqual(x, x_exp) - - # test against numpy.linalg.solve - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b - def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_symmetric_pd_matrix diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9215a66b098..cb2df20492cd 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -94,7 +94,7 @@ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', - 'exp', 'nonzero', 'mean', 'inverse' + 'exp', 'nonzero', 'mean', 'inverse', 'solve' } # Some operators invalidate the grad_accumulator. Let's reset it. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index afc71cff8304..6755153d88e6 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3225,6 +3225,8 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `solution, LU`. +Supports real-valued and complex-valued inputs. + .. note:: Irrespective of the original strides, the returned matrices diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 48ef1a796b50..674c5bbbe607 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -422,15 +422,15 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di } Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) { - return std::get<0>(at::solve(grad, A.transpose(-2, -1))); + return std::get<0>(at::solve(grad, A.conj().transpose(-2, -1))); } Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { Tensor grad_self = solve_backward_self(grad, self, A); if (self.ndimension() == 2 && A.ndimension() == 2) { - return -at::mm(grad_self, solution.transpose(-2, -1)); + return -at::mm(grad_self, solution.conj().transpose(-2, -1)); } - return -at::matmul(grad_self, solution.transpose(-2, -1)); + return -at::matmul(grad_self, solution.conj().transpose(-2, -1)); } Tensor cumsum_backward(const Tensor & x, int64_t dim) { diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index ad0badf5eed9..bf2947da81c8 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -146,7 +146,7 @@ Computes a tensor ``x`` such that ``tensordot(input, x, dims=x.ndim) = other``. The resulting tensor ``x`` has the same shape as ``input[other.ndim:]``. -Supports real-valued and, only on the CPU, complex-valued inputs. +Supports real-valued and complex-valued inputs. .. note:: If :attr:`input` does not satisfy the requirement ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])`` diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5d4b68178416..e62d62eb12bf 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1325,15 +1325,23 @@ def method_tests(): ('lu', (3, S, S), (True, True), 'square_batch_with_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('lu', (3, 3, S, S), (True, False), 'square_many_batches_no_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('lu', (3, 3, S, S), (True, True), 'square_many_batches_with_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value( - S, silent=True),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),), + ('solve', (S, S), (lambda dtype, device: random_fullrank_matrix_distinct_singular_value( + S, silent=True, dtype=dtype, device=device),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('solve', (S, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, S, silent=True, dtype=dtype, device=device),), 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),), + ('solve', (2, 3, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True, dtype=dtype, device=device),), 'batched_dims', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),), + ('solve', (2, 2, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 1, silent=True, dtype=dtype, device=device),), 'batched_broadcast_A', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),), + ('solve', (1, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True, dtype=dtype, device=device),), 'batched_broadcast_b', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('fill_', (S, S, S), (1,), 'number'), ('fill_', (), (1,), 'number_scalar'), From b1a4170ab3add6347f69ae11ffb43b0345ba9d78 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Thu, 12 Nov 2020 12:31:19 -0800 Subject: [PATCH 74/93] [NNC] Fix lowering of aten::pow (#47795) Summary: NNC lowering of aten::pow assumes that the types of the exponent is either float or int cast to to float, which doesn't work great with double (or half for that matter). Fixes https://github.com/pytorch/pytorch/issues/47304 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47795 Reviewed By: ZolotukhinM Differential Revision: D24904201 Pulled By: nickgg fbshipit-source-id: 43c3ea704399ebb36c33cd222db16c60e5b7ada5 --- test/test_tensorexpr.py | 15 ++++++ torch/csrc/jit/tensorexpr/kernel.cpp | 68 +++++++++------------------- 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index b673e0233a0d..7797d58b4cc8 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -1307,6 +1307,21 @@ def bias_gelu(bias, y): x = warmup_and_run_forward(traced, a, b) self.assertLastGraphAllFused() + def test_exp_pow(self): + devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] + + @torch.jit.script + def do_exp(x, y, z): + return ((x * y) * 2) * torch.pow(z, 2) + + for device in devices: + x = torch.rand(10, dtype=torch.double, device=device) + y = torch.rand(10, dtype=torch.double, device=device) + z = torch.rand(10, dtype=torch.double, device=device) + traced = torch.jit.trace(do_exp, (x, y, z)) + x = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() + def test_transpose(self): @torch.jit.script def test(x, y, z): diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index bdd74cefc9a0..24bfedc92841 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1019,54 +1019,30 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::pow: { return computeTwoOperand( "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - const FloatImm* floatImm = rhs.AsNode(); - if (floatImm) { - float imm = floatImm->value(); - if (imm == 1.0f) { - return lhs; - } else if (imm == 2.0f) { // NOLINT - return lhs * lhs; - } else if (imm == 3.0f) { // NOLINT - return (lhs * lhs) * lhs; - } else if (imm == 4.0f) { // NOLINT - ExprHandle tmp = lhs * lhs; - return tmp * tmp; - } else if (imm == 0.5f) { // NOLINT - return sqrt(lhs); - } else if (imm == 0.0f) { - return ExprHandle(1.0f); - } else if (imm == -0.5f) { // NOLINT - return rsqrt(lhs); - } else if (imm == -1.0f) { - return ExprHandle(1.0f) / lhs; - } else if (imm == -2.0f) { // NOLINT - return ExprHandle(1.0f) / (lhs * lhs); - } + double val = 0; + if (rhs.node()->isConstant()) { + val = immediateAs(IRSimplifier::simplify(rhs.node())); } - const Cast* floatCast = rhs.AsNode(); - if (floatCast) { - const IntImm* intImm = - dynamic_cast(floatCast->src_value()); - if (intImm) { - float imm = static_cast(intImm->value()); - if (imm == 1) { - return lhs; - } else if (imm == 2) { - return lhs * lhs; - } else if (imm == 3) { - return (lhs * lhs) * lhs; - } else if (imm == 4) { - ExprHandle tmp = lhs * lhs; - return tmp * tmp; - } else if (imm == 0) { - return ExprHandle(1.0f); - } else if (imm == -1) { - return ExprHandle(1.0f) / lhs; - } else if (imm == -2) { - return ExprHandle(1.0f) / (lhs * lhs); - } - } + if (val == 1.0f) { + return lhs; + } else if (val == 2.0f) { // NOLINT + return lhs * lhs; + } else if (val == 3.0f) { // NOLINT + return (lhs * lhs) * lhs; + } else if (val == 4.0f) { // NOLINT + ExprHandle tmp = lhs * lhs; + return tmp * tmp; + } else if (val == 0.5f) { // NOLINT + return sqrt(lhs); + } else if (val == 0.0f) { + return ExprHandle(1.0f); + } else if (val == -0.5f) { // NOLINT + return rsqrt(lhs); + } else if (val == -1.0f) { + return ExprHandle(1.0f) / lhs; + } else if (val == -2.0f) { // NOLINT + return ExprHandle(1.0f) / (lhs * lhs); } return pow(lhs, rhs); }); From 8304c25c67ec19ba4e857ce3462e213835409078 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Thu, 12 Nov 2020 13:33:53 -0800 Subject: [PATCH 75/93] Give hash in commit messages in doc push scripts (#47694) Summary: This PR replaces the current auto-generated commit messages like pytorch/pytorch.github.io@fb217ab34abee258575613b13641939eca1b0fe1 (currently includes no information) and pytorch/cppdocs@7efd67e8f1ff6a599f0870785a1540efa515970a (currently includes only a timestamp, which is redundant since it's a Git commit) with more descriptive ones that specify the pytorch/pytorch commit they originated from. This information would be useful for debugging issues such as https://github.com/pytorch/pytorch/issues/47462. GitHub will also [autolink](https://docs.github.com/en/free-pro-team@latest/github/writing-on-github/autolinked-references-and-urls#commit-shas) these new messages (similar to ezyang/pytorch-ci-hud@bc25ae770d1088629fa17f5a5ce34aa94ce173e6), and so they will now also mostly follow Git commit message conventions by starting with a capital letter, using the imperative voice, and (at least in the autolink-rendered form on GitHub, although not in the raw text) staying under 50 characters. **Question for reviewers:** Will my `export CIRCLE_SHA1="$CIRCLE_SHA1"` work here? Is it necessary? Pull Request resolved: https://github.com/pytorch/pytorch/pull/47694 Reviewed By: walterddr Differential Revision: D24868240 Pulled By: samestep fbshipit-source-id: 4907341e7b57ed6818ab550dc1ec423f2c2450c1 --- .circleci/config.yml | 2 ++ .circleci/scripts/cpp_doc_push_script.sh | 2 +- .circleci/scripts/python_doc_push_script.sh | 2 +- .circleci/verbatim-sources/job-specs/job-specs-custom.yml | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 2be90283d295..bc4ae033e8e6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1205,6 +1205,7 @@ jobs: export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/'$target' master site") | docker exec -u jenkins -i "$id" bash) 2>&1' + export CIRCLE_SHA1="$CIRCLE_SHA1" echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts mkdir -p ~/workspace/build_artifacts @@ -1250,6 +1251,7 @@ jobs: export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/cpp_doc_push_script.sh docs/"$target" master") | docker exec -u jenkins -i "$id" bash) 2>&1' + export CIRCLE_SHA1="$CIRCLE_SHA1" echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts mkdir -p ~/workspace/build_artifacts diff --git a/.circleci/scripts/cpp_doc_push_script.sh b/.circleci/scripts/cpp_doc_push_script.sh index 198e93d58f8e..c6b4f00a06f0 100755 --- a/.circleci/scripts/cpp_doc_push_script.sh +++ b/.circleci/scripts/cpp_doc_push_script.sh @@ -88,7 +88,7 @@ git status git config user.email "soumith+bot@pytorch.org" git config user.name "pytorchbot" # If there aren't changes, don't make a commit; push is no-op -git commit -m "Automatic sync on $(date)" || true +git commit -m "Generate C++ docs from pytorch/pytorch@$CIRCLE_SHA1" || true git status popd diff --git a/.circleci/scripts/python_doc_push_script.sh b/.circleci/scripts/python_doc_push_script.sh index 4da8d546d36a..9061eb9d1d85 100755 --- a/.circleci/scripts/python_doc_push_script.sh +++ b/.circleci/scripts/python_doc_push_script.sh @@ -107,7 +107,7 @@ git status git config user.email "soumith+bot@pytorch.org" git config user.name "pytorchbot" # If there aren't changes, don't make a commit; push is no-op -git commit -m "auto-generating sphinx docs" || true +git commit -m "Generate Python docs from pytorch/pytorch@$CIRCLE_SHA1" || true git status popd diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml index d5f07eefb4e2..7fa561d66080 100644 --- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml @@ -51,6 +51,7 @@ export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/'$target' master site") | docker exec -u jenkins -i "$id" bash) 2>&1' + export CIRCLE_SHA1="$CIRCLE_SHA1" echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts mkdir -p ~/workspace/build_artifacts @@ -96,6 +97,7 @@ export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/cpp_doc_push_script.sh docs/"$target" master") | docker exec -u jenkins -i "$id" bash) 2>&1' + export CIRCLE_SHA1="$CIRCLE_SHA1" echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts mkdir -p ~/workspace/build_artifacts From 65d5004b09fd8d5deac173a3aaa259f46eaa0d67 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Thu, 12 Nov 2020 13:43:36 -0800 Subject: [PATCH 76/93] Update, appease, and enable fail-on for shellcheck (#47786) Summary: Currently ([example](https://github.com/pytorch/pytorch/runs/1381883195)), ShellCheck is run on `*.sh` files in `.jenkins/pytorch`, but it uses a three-and-a-half-year-old version, and doesn't fail the lint job despite yielding many warnings. This PR does the following: - update ShellCheck to v0.7.1 (and generally make it always use the latest `"stable"` release), to get more warnings and also enable the directory-wide directives that were introduced in v0.7.0 (see the next bullet) - move the rule exclusions list from a variable in `.jenkins/run-shellcheck.sh` to a [declarative file](https://github.com/koalaman/shellcheck/issues/725#issuecomment-469102071) `.jenkins/pytorch/.shellcheckrc`, so now editor integrations such as [vscode-shellcheck](https://github.com/timonwong/vscode-shellcheck) give the same warnings as the CLI script - fix all ShellCheck warnings in `.jenkins/pytorch` - remove the suppression of ShellCheck's return value, so now it will fail the lint job if new warnings are introduced --- While working on this, I was confused because I was getting fairly different results from running ShellCheck locally versus what I saw in the CI logs, and also different results among the laptop and devservers I was using. Part of this was due to different versions of ShellCheck, but there were even differences within the same version. For instance, this command should reproduce the results in CI by using (almost) exactly the same environment: ```bash act -P ubuntu-latest=nektos/act-environments-ubuntu:18.04 -j quick-checks \ | sed '1,/Run Shellcheck Jenkins scripts/d;/Success - Shellcheck Jenkins scripts/,$d' \ | cut -c25- ``` But the various warnings were being displayed in different orders, so it was hard to tell at a glance whether I was getting the same result set or not. However, piping the results into this ShellCheck-output-sorting Python script showed that they were in fact the same: ```python import fileinput items = ''.join(fileinput.input()).split('\n\n') print(''.join(sorted(f'\n{item.strip()}\n\n' for item in items)), end='') ``` Note that while the above little script worked for the old version (v0.4.6) that was previously being used in CI, it is a bit brittle, and will not give great results in more recent ShellCheck versions (since they give more different kinds of output besides just a list of warnings). Pull Request resolved: https://github.com/pytorch/pytorch/pull/47786 Reviewed By: seemethere Differential Revision: D24900522 Pulled By: samestep fbshipit-source-id: 92d66e1d5d28a77de5a4274411598cdd28b7d436 --- .github/workflows/lint.yml | 7 ++++++- .jenkins/pytorch/.shellcheckrc | 6 ++++++ .jenkins/pytorch/build-mobile-code-analysis.sh | 1 + .jenkins/pytorch/build-mobile.sh | 1 + .jenkins/pytorch/build.sh | 2 +- .jenkins/pytorch/codegen-test.sh | 1 + .jenkins/pytorch/common.sh | 7 ++++--- .jenkins/pytorch/common_utils.sh | 2 +- .jenkins/pytorch/macos-build.sh | 4 ++-- .../perf_test/test_cpu_speed_mini_sequence_labeler.sh | 2 +- .jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh | 2 +- .jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh | 2 +- .jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh | 2 +- .jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh | 2 +- .jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh | 2 +- .../perf_test/test_gpu_speed_word_language_model.sh | 2 +- .jenkins/pytorch/short-perf-test-cpu.sh | 5 ++--- .jenkins/pytorch/short-perf-test-gpu.sh | 5 ++--- .jenkins/pytorch/test.sh | 8 ++++---- .jenkins/pytorch/win-build.sh | 2 +- .jenkins/pytorch/win-test.sh | 2 +- .jenkins/run-shellcheck.sh | 4 +--- 22 files changed, 41 insertions(+), 30 deletions(-) create mode 100644 .jenkins/pytorch/.shellcheckrc diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8fdccf101af7..1cdf9550d7a5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,8 +22,13 @@ jobs: pip install -r requirements.txt cd .circleci && ./ensure-consistency.py - name: Shellcheck Jenkins scripts + # https://github.com/koalaman/shellcheck#installing-a-pre-compiled-binary run: | - sudo apt-get install -y shellcheck + scversion="stable" + wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv + sudo cp "shellcheck-${scversion}/shellcheck" /usr/bin/ + rm -r "shellcheck-${scversion}" + shellcheck --version .jenkins/run-shellcheck.sh - name: Ensure no tabs run: | diff --git a/.jenkins/pytorch/.shellcheckrc b/.jenkins/pytorch/.shellcheckrc new file mode 100644 index 000000000000..ff96b057e50a --- /dev/null +++ b/.jenkins/pytorch/.shellcheckrc @@ -0,0 +1,6 @@ +disable=SC2086 +disable=SC1091 +disable=SC2155 +disable=SC1090 +disable=SC2164 +disable=SC1003 diff --git a/.jenkins/pytorch/build-mobile-code-analysis.sh b/.jenkins/pytorch/build-mobile-code-analysis.sh index 982ab257f84d..0e6d4be88be3 100755 --- a/.jenkins/pytorch/build-mobile-code-analysis.sh +++ b/.jenkins/pytorch/build-mobile-code-analysis.sh @@ -5,6 +5,7 @@ set -eu -o pipefail # This script builds and runs code analyzer tool to generate aten op dependency # graph for custom mobile build. +# shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" diff --git a/.jenkins/pytorch/build-mobile.sh b/.jenkins/pytorch/build-mobile.sh index b1234f272813..3ffec5074171 100755 --- a/.jenkins/pytorch/build-mobile.sh +++ b/.jenkins/pytorch/build-mobile.sh @@ -6,6 +6,7 @@ set -eu -o pipefail # build & test mobile libtorch without having to setup Android/iOS # toolchain/simulator. +# shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index b6e21c363133..b2b71b559bdf 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -165,7 +165,7 @@ fi # sccache will fail for CUDA builds if all cores are used for compiling # gcc 7 with sccache seems to have intermittent OOM issue if all cores are used if [ -z "$MAX_JOBS" ]; then - if ([[ "$BUILD_ENVIRONMENT" == *cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *gcc7* ]]) && which sccache > /dev/null; then + if { [[ "$BUILD_ENVIRONMENT" == *cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *gcc7* ]]; } && which sccache > /dev/null; then export MAX_JOBS=$(($(nproc) - 1)) fi fi diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh index 0b3729478bcd..dee29d36fc9f 100755 --- a/.jenkins/pytorch/codegen-test.sh +++ b/.jenkins/pytorch/codegen-test.sh @@ -12,6 +12,7 @@ set -eu -o pipefail if [ "$#" -eq 0 ]; then + # shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" OUT="$(dirname "${BASH_SOURCE[0]}")/../../codegen_result" diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index 96e3f5d8ede1..3175c3fa75af 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -18,7 +18,7 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]] && [[ "${BUILD_ENVIRONMENT}" =~ py((2| # non-interactive bashs do not expand aliases by default shopt -s expand_aliases export PYTORCH_TEST_WITH_ROCM=1 - alias python="$PYTHON" + alias python='$PYTHON' # temporary to locate some kernel issues on the CI nodes export HSAKMT_DEBUG_LEVEL=4 fi @@ -45,7 +45,7 @@ fatal() { error "$@"; exit 1; } # - remaining args: names of traps to modify # trap_add() { - trap_add_cmd=$1; shift || fatal "${FUNCNAME} usage error" + trap_add_cmd=$1; shift || fatal "${FUNCNAME[0]} usage error" for trap_add_name in "$@"; do trap -- "$( # helper fn to get existing trap command from output @@ -116,6 +116,7 @@ if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda10.1-cudnn7-py3* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch_macos* ]]; then BUILD_TEST_LIBTORCH=1 else + # shellcheck disable=SC2034 BUILD_TEST_LIBTORCH=0 fi @@ -138,5 +139,5 @@ if [[ "$BUILD_ENVIRONMENT" == *pytorch-xla-linux-bionic* ]] || \ fi retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) + "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") } diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index 24d6f5676f7d..b28dcb2f41d8 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -18,7 +18,7 @@ function cleanup { function assert_git_not_dirty() { # TODO: we should add an option to `build_amd.py` that reverts the repo to # an unmodified state. - if ([[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *xla* ]]) ; then + if [[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *xla* ]] ; then git_status=$(git status --porcelain) if [[ $git_status ]]; then echo "Build left local git repository checkout dirty" diff --git a/.jenkins/pytorch/macos-build.sh b/.jenkins/pytorch/macos-build.sh index 140c29cc9642..25bf368e86ef 100755 --- a/.jenkins/pytorch/macos-build.sh +++ b/.jenkins/pytorch/macos-build.sh @@ -13,10 +13,10 @@ if [ -z "${IN_CI}" ]; then fi if which sccache > /dev/null; then - printf "#!/bin/sh\nexec sccache $(which clang++) \$*" > "${WORKSPACE_DIR}/clang++" + printf "#!/bin/sh\nexec sccache %s \$*" "$(which clang++)" > "${WORKSPACE_DIR}/clang++" chmod a+x "${WORKSPACE_DIR}/clang++" - printf "#!/bin/sh\nexec sccache $(which clang) \$*" > "${WORKSPACE_DIR}/clang" + printf "#!/bin/sh\nexec sccache %s \$*" "$(which clang)" > "${WORKSPACE_DIR}/clang" chmod a+x "${WORKSPACE_DIR}/clang" export PATH="${WORKSPACE_DIR}:$PATH" diff --git a/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh b/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh index 795251fc8625..4f86eb88fe0c 100644 --- a/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh +++ b/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh @@ -21,7 +21,7 @@ test_cpu_speed_mini_sequence_labeler () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py) - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../../.. diff --git a/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh b/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh index 29086fbc9976..e284bb3aa6cc 100644 --- a/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh +++ b/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh @@ -23,7 +23,7 @@ test_cpu_speed_mnist () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh index 667cfba617fc..25109fdb8428 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_cudnn_lstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python cudnn_lstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh index ea220b33ac7c..e0f629cde86e 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_lstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python lstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh index 62b94a7b21d1..46bb2b5ba2e3 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_mlstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python mlstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh index 2453f3d70e70..2868cfca30c3 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh @@ -26,7 +26,7 @@ test_gpu_speed_mnist () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh index b1dea09c7c34..d0ae3160a22b 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh @@ -31,7 +31,7 @@ test_gpu_speed_word_language_model () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --cuda --epochs 1) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/short-perf-test-cpu.sh b/.jenkins/pytorch/short-perf-test-cpu.sh index ae838276bd4d..a77e6245e355 100755 --- a/.jenkins/pytorch/short-perf-test-cpu.sh +++ b/.jenkins/pytorch/short-perf-test-cpu.sh @@ -27,13 +27,12 @@ fi git remote add upstream https://github.com/pytorch/pytorch.git git fetch upstream IFS=$'\n' -master_commit_ids=($(git rev-list upstream/master)) -for commit_id in "${master_commit_ids[@]}"; do +while IFS='' read -r commit_id; do if aws s3 ls s3://ossci-perf-test/pytorch/cpu_runtime/${commit_id}.json; then LATEST_TESTED_COMMIT=${commit_id} break fi -done +done < <(git rev-list upstream/master) aws s3 cp s3://ossci-perf-test/pytorch/cpu_runtime/${LATEST_TESTED_COMMIT}.json cpu_runtime.json if [[ "$COMMIT_SOURCE" == master ]]; then diff --git a/.jenkins/pytorch/short-perf-test-gpu.sh b/.jenkins/pytorch/short-perf-test-gpu.sh index 8fd701e19720..ec445409390b 100755 --- a/.jenkins/pytorch/short-perf-test-gpu.sh +++ b/.jenkins/pytorch/short-perf-test-gpu.sh @@ -26,13 +26,12 @@ fi git remote add upstream https://github.com/pytorch/pytorch.git git fetch upstream IFS=$'\n' -master_commit_ids=($(git rev-list upstream/master)) -for commit_id in "${master_commit_ids[@]}"; do +while IFS='' read -r commit_id; do if aws s3 ls s3://ossci-perf-test/pytorch/gpu_runtime/${commit_id}.json; then LATEST_TESTED_COMMIT=${commit_id} break fi -done +done < <(git rev-list upstream/master) aws s3 cp s3://ossci-perf-test/pytorch/gpu_runtime/${LATEST_TESTED_COMMIT}.json gpu_runtime.json if [[ "$COMMIT_SOURCE" == master ]]; then diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 78ba67c088ee..1f1f174e992e 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -22,7 +22,7 @@ fi if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # Print GPU info - rocminfo | egrep 'Name:.*\sgfx|Marketing' + rocminfo | grep -E 'Name:.*\sgfx|Marketing' fi # --user breaks ppc64le builds and these packages are already in ppc64le docker @@ -93,7 +93,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* ]]; then export ATEN_CPU_CAPABILITY=avx fi -if ([ -n "$CIRCLE_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]); then +if [ -n "$CIRCLE_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then DETERMINE_FROM=$(mktemp) file_diff_from_base "$DETERMINE_FROM" fi @@ -117,7 +117,7 @@ test_aten() { # Test ATen # The following test(s) of ATen have already been skipped by caffe2 in rocm environment: # scalar_tensor_test, basic, native_test - if ([[ "$BUILD_ENVIRONMENT" != *asan* ]] && [[ "$BUILD_ENVIRONMENT" != *rocm* ]]); then + if [[ "$BUILD_ENVIRONMENT" != *asan* ]] && [[ "$BUILD_ENVIRONMENT" != *rocm* ]]; then echo "Running ATen tests with pytorch lib" TORCH_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/torch/lib # NB: the ATen test binaries don't have RPATH set, so it's necessary to @@ -255,7 +255,7 @@ test_torch_function_benchmark() { test_xla() { export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" # Issue #30717: randomize the port of XLA/gRPC workers is listening on to reduce flaky tests. - XLA_PORT=`shuf -i 40701-40999 -n 1` + XLA_PORT=$(shuf -i 40701-40999 -n 1) export XRT_WORKERS="localservice:0;grpc://localhost:$XLA_PORT" pushd xla echo "Running Python Tests" diff --git a/.jenkins/pytorch/win-build.sh b/.jenkins/pytorch/win-build.sh index 6df36ce4d021..285b5c4e94d4 100755 --- a/.jenkins/pytorch/win-build.sh +++ b/.jenkins/pytorch/win-build.sh @@ -15,7 +15,7 @@ COMPACT_JOB_NAME=pytorch-win-ws2019-cuda10-cudnn7-py3-build SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) source "$SCRIPT_PARENT_DIR/common.sh" -export IMAGE_COMMIT_ID=`git rev-parse HEAD` +export IMAGE_COMMIT_ID=$(git rev-parse HEAD) export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} if [[ ${JOB_NAME} == *"develop"* ]]; then export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG} diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index 768e4931e485..c1c49cd711b8 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -6,7 +6,7 @@ COMPACT_JOB_NAME=pytorch-win-ws2019-cuda10-cudnn7-py3-test SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) source "$SCRIPT_PARENT_DIR/common.sh" -export IMAGE_COMMIT_ID=`git rev-parse HEAD` +export IMAGE_COMMIT_ID=$(git rev-parse HEAD) export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} if [[ ${JOB_NAME} == *"develop"* ]]; then export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG} diff --git a/.jenkins/run-shellcheck.sh b/.jenkins/run-shellcheck.sh index 1333e9ab6f49..5c64655b578f 100755 --- a/.jenkins/run-shellcheck.sh +++ b/.jenkins/run-shellcheck.sh @@ -5,6 +5,4 @@ # .jenkins/run-shellcheck.sh --color=always | less -R -EXCLUSIONS=SC2086,SC1091,SC2155,SC1090,SC2164,SC1003 - -find .jenkins/pytorch -name *.sh | xargs shellcheck --exclude=$EXCLUSIONS --external-sources "$@" || true +find .jenkins/pytorch -name *.sh | xargs shellcheck --external-sources "$@" From 8da75763032fb74f16c92fd9a8c98e69cd0b9ce3 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Thu, 12 Nov 2020 14:18:17 -0800 Subject: [PATCH 77/93] Remove `balance` and `devices` parameter from Pipe. (#46804) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46804 As per our design in https://github.com/pytorch/pytorch/issues/44827, changign the API such that the user places modules on appropriate devices instead of having a `balance` and `devices` parameter that decides this. This design allows us to use RemoteModule in the future. ghstack-source-id: 116479842 Test Plan: waitforbuildbot Reviewed By: mrshenli Differential Revision: D24524219 fbshipit-source-id: 9973172c2bb7636572cdc37ce06bf8368638a463 --- .../_pipeline/sync/skip/test_gpipe.py | 6 +- .../_pipeline/sync/skip/test_leak.py | 2 +- test/distributed/_pipeline/sync/test_bugs.py | 20 ++- .../_pipeline/sync/test_inplace.py | 2 +- test/distributed/_pipeline/sync/test_pipe.py | 167 +++++++----------- .../_pipeline/sync/test_transparency.py | 2 +- torch/distributed/_pipeline/sync/pipe.py | 144 +++++---------- .../distributed/pipeline/__init__.py | 0 .../_internal/distributed/pipeline/utils.py | 21 +++ 9 files changed, 155 insertions(+), 209 deletions(-) create mode 100644 torch/testing/_internal/distributed/pipeline/__init__.py create mode 100644 torch/testing/_internal/distributed/pipeline/utils.py diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 293a263439bc..96ecd84e0d18 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -11,6 +11,7 @@ from torch.distributed._pipeline.sync import Pipe from torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange +from torch.testing._internal.distributed.pipeline.utils import convert_to_balance @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @@ -52,7 +53,8 @@ def forward(self, input): return output model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = Pipe(model, balance, chunks=3, checkpoint=checkpoint) + model = convert_to_balance(model, balance) + model = Pipe(model, chunks=3, checkpoint=checkpoint) in_device = model.devices[0] out_device = model.devices[-1] @@ -81,7 +83,7 @@ def forward(self, input): return input model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5) + model = Pipe(model, chunks=5) input = torch.rand(10, requires_grad=True) output = model(input) diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 89e39aa9cedb..31c4ea13b9f1 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -91,7 +91,7 @@ def forward(self, input): return self.F.apply(input) model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint) + model = Pipe(model, chunks=2, checkpoint=checkpoint) input = torch.rand(10, requires_grad=True) diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index c3152745b5bb..4f5346a837b5 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -37,7 +37,7 @@ def forward(self, input): return Identity.apply(input) model = nn.Sequential(M(), M()) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always") + model = Pipe(model, checkpoint="always") x = torch.rand(42) y = model(x) @@ -62,7 +62,7 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3) + model = Pipe(model, chunks=3) with pytest.raises(ExpectedException): model(torch.rand(3)) @@ -86,18 +86,28 @@ def backward(ctx, grad): return grad class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + def forward(self, pair): a, b = pair + a = a * self.ones return a * 1, b * 2, b * 3 class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + def forward(self, triple): a, b, c = triple + a = a * self.ones b = Sleep.apply(b) return a + b + c - model = nn.Sequential(Layer1(), Layer2()) - model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never") + model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) + model = Pipe(model, chunks=32, checkpoint="never") a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) @@ -121,7 +131,7 @@ def forward(self, x): model = nn.Sequential(Dropouts(), Dropouts()) x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always") + model = Pipe(model, chunks=10, checkpoint="always") y = model(x) y.norm().backward() diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 185ad8706054..17b3dac4eca8 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -13,7 +13,7 @@ def test_inplace_on_requires_grad(): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always") + model = Pipe(model, checkpoint="always") x = torch.rand(1) y = model(x) diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index d7915733adc0..9c2964940576 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -19,7 +19,7 @@ def test_parameters(): model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1) + pipe = Pipe(model, chunks=1) assert list(pipe.parameters()) != [] @@ -32,9 +32,8 @@ def __str__(self): return self.value model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always")) + pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) - assert pipe.balance == [1] assert pipe.devices == [torch.device("cpu")] assert pipe.chunks == 42 assert isinstance(pipe.chunks, int) @@ -42,13 +41,12 @@ def __str__(self): assert isinstance(pipe.checkpoint, str) -@pytest.mark.parametrize("balance", [[2], [1, 1]]) -def test_sequential_like(balance): +def test_sequential_like(): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, balance, devices=["cpu", "cpu"]) + model = Pipe(model) assert len(model) == 2 assert list(model) == [a, b] @@ -61,54 +59,18 @@ def test_sequential_like(balance): assert model[-1] is b assert model[-2] is a - -def test_balance_wrong_length(): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - - with pytest.raises(ValueError): - Pipe(model, balance=[1]) - - with pytest.raises(ValueError): - Pipe(model, balance=[3]) - - -def test_balance_less_than_1(): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - - with pytest.raises(ValueError): - Pipe(model, balance=[0, 2]) - - with pytest.raises(ValueError): - Pipe(model, balance=[-1, 3]) - - def test_chunks_less_than_1(): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError): - Pipe(model, balance=[1], devices=["cpu"], chunks=0) + Pipe(model, chunks=0) with pytest.raises(ValueError): - Pipe(model, balance=[1], devices=["cpu"], chunks=-1) - - -def test_too_few_devices(): - model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) - - with pytest.raises(IndexError): - # len(balance) > len(devices) - model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"]) - + Pipe(model, chunks=-1) def test_batch_size_indivisible(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) + model = Pipe(model, chunks=4) with pytest.warns(None) as record: model(torch.rand(7, 1)) @@ -119,7 +81,7 @@ def test_batch_size_indivisible(): def test_batch_size_small(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) + model = Pipe(model, chunks=4) with pytest.warns(None) as record: model(torch.rand(2, 1)) @@ -149,9 +111,9 @@ def count_grad_fn(grad_fn, name, visited=None): model = nn.Sequential(nn.Linear(1, 1)) input = torch.rand(2, 1) - always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always") - except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last") - never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never") + always = Pipe(model, chunks=2, checkpoint="always") + except_last = Pipe(model, chunks=2, checkpoint="except_last") + never = Pipe(model, chunks=2, checkpoint="never") always_output = always(input) except_last_output = except_last(input) @@ -166,21 +128,21 @@ def test_checkpoint_mode_invalid(): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): - Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT") + Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") def test_checkpoint_mode_when_chunks_1(): model = nn.Sequential(nn.Linear(1, 1)) # All checkpoint modes are fine. - Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last") - Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always") - Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never") + Pipe(model, chunks=1, checkpoint="except_last") + Pipe(model, chunks=1, checkpoint="always") + Pipe(model, chunks=1, checkpoint="never") def test_checkpoint_eval(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + model = Pipe(model, chunks=2) input = torch.rand(2, 1) def find_grad_fn(grad_fn, name): @@ -214,7 +176,7 @@ def forward(self, input): return input[0] * 2 model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always") + model = Pipe(model, chunks=1, checkpoint="always") input = torch.rand(1, requires_grad=True) output = model(input) @@ -223,7 +185,7 @@ def forward(self, input): def test_no_grad(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + model = Pipe(model, chunks=2) input = torch.rand(2, 1) latent = None @@ -253,7 +215,7 @@ def forward(self, *_): raise ExpectedException() model = nn.Sequential(Raise()) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=1) + model = Pipe(model, chunks=1) with pytest.raises(ExpectedException): model(torch.rand(1)) @@ -287,7 +249,7 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3) + model = Pipe(model, chunks=3) with pytest.raises(ExpectedException): model(torch.rand(3)) @@ -308,7 +270,7 @@ def forward(self, a_and_b): return (self.fc_a(a), self.fc_b(b)) model = nn.Sequential(Two()) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + model = Pipe(model, chunks=2) a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) @@ -332,7 +294,7 @@ def forward(self, only_a): return (self.fc(a),) model = nn.Sequential(One()) - model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + model = Pipe(model, chunks=2) a = torch.rand(10, 1, requires_grad=True) @@ -346,7 +308,7 @@ def forward(self, only_a): def test_input_varargs(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], devices=["cpu"]) + model = Pipe(model) a = torch.rand(1) b = torch.rand(1) @@ -362,7 +324,7 @@ def forward(self, _): return "hello" model = nn.Sequential(NonTensor()) - model = Pipe(model, balance=[1], devices=["cpu"]) + model = Pipe(model) x = torch.rand(1) # TypeError: expected Tensor as element 0 in argument 0, but got str @@ -380,7 +342,7 @@ def forward(self, x): return (x, "hello") model = nn.Sequential(NonTensorTuple()) - model = Pipe(model, balance=[1], devices=["cpu"]) + model = Pipe(model) x = torch.rand(1) # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 @@ -397,7 +359,7 @@ def test_deferred_batch_norm(checkpoint): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( - nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True + nn.Sequential(pipe_bn), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True ) x = torch.rand(4, 3, 10, 10) @@ -413,7 +375,7 @@ def test_deferred_batch_norm_params(checkpoint): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( - nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True + nn.Sequential(pipe_bn), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True ) x = torch.rand(4, 3, 10, 10) @@ -433,10 +395,8 @@ def test_devices(): c = nn.Linear(1, 1) # There are extra two devices. - devices = ["cpu", "cpu", "cpu", "cpu", "cpu"] - model = nn.Sequential(a, b, c) - model = Pipe(model, [1, 1, 1], devices=devices) + model = Pipe(model) cpu = torch.device("cpu") # Extra devices must be discarded. @@ -448,7 +408,7 @@ def test_partitions(): b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + model = Pipe(model) assert isinstance(model.partitions, nn.ModuleList) assert isinstance(model.partitions[0], nn.Sequential) @@ -462,7 +422,7 @@ def test_deny_moving(): b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + model = Pipe(model) # Moving is denied. with pytest.raises(TypeError): @@ -498,7 +458,7 @@ def test_deny_moving(): def test_empty_module(): # Empty sequential module is not illegal. model = nn.Sequential() - model = Pipe(model, []) + model = Pipe(model) assert model(torch.tensor(42)) == torch.tensor(42) assert model((torch.tensor(42),)) == (torch.tensor(42),) @@ -513,7 +473,7 @@ def test_named_children(): b = nn.Linear(1, 1) model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + model = Pipe(model) names = set(n for n, _ in model.named_modules()) assert "partitions.0.a" in names @@ -525,23 +485,9 @@ def test_named_children(): model.a -def test_recommend_auto_balance(): - with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): - # balance is required - Pipe(nn.Sequential()) - - with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): - # module and sum of balance have differen length (module: 0, sum of balance: 1) - Pipe(nn.Sequential(), [1]) - - with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): - # module and sum of balance have different length (module: 2, sum of balance: 1) - Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) - - def test_verify_module_non_sequential(): with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): - Pipe(nn.Module(), [1]) + Pipe(nn.Module()) def test_verify_module_duplicate_children(): @@ -549,22 +495,45 @@ def test_verify_module_duplicate_children(): model = nn.Sequential(conv, conv) with pytest.raises(ValueError, match="module with duplicate children is not supported"): - Pipe(model, [1, 1]) + Pipe(model) @skip_if_no_cuda -def test_verify_module_duplicate_parameters_on_distinct_devices(): +def test_verify_module_params_on_same_device(): class Surrogate(nn.Module): - def __init__(self, module): + def __init__(self, param1, param2): super().__init__() - self.module = module - - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - - with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): - Pipe(model, [1, 1], devices=["cpu", "cuda"]) + self.param1 = param1 + self.param2 = param2 + + conv1 = nn.Conv2d(3, 3, 1) + conv2 = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv1, conv2.cuda())) + + with pytest.raises( + ValueError, + match='should have all parameters on a single device, please use .to\(\)' + ' to place the module on a single device'): + Pipe(model) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") +def test_verify_nested_modules(): + model = nn.Sequential( + nn.Sequential( + nn.Linear(32, 16).cuda(0), + nn.Linear(16, 8).cuda(0) + ), + nn.Sequential( + nn.Linear(8, 4).cuda(1), + nn.Linear(4, 2).cuda(1) + ), + ) + pipe = Pipe(model) + out = pipe(torch.rand(10, 32).cuda(0)) + assert out.device == torch.device("cuda:1") + assert out.size() == torch.Size([10, 2]) def test_verify_module_duplicate_parameters_on_same_device(): class Surrogate(nn.Module): @@ -575,7 +544,7 @@ def __init__(self, module): conv = nn.Conv2d(3, 3, 1) model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - Pipe(model, [1, 1], devices=["cpu", "cpu"]) + Pipe(model) def test_forward_lockstep(): @@ -597,7 +566,7 @@ def forward(self, x): return x model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3) + model = Pipe(model, chunks=3) model(torch.rand(3, 1)) # Expected timeline: (Logs are recorded at !) diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 88d9c83b9a07..3d2c77e8fef4 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -31,7 +31,7 @@ def zero_grad(parameters): zero_grad(model.parameters()) # With Pipe - model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4) + model = Pipe(model, chunks=4) outputs = model(inputs) loss = outputs.mean() diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index 500b15b72771..68906958cc0e 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -65,8 +65,8 @@ def verify_module(module: nn.Sequential) -> None: raise ValueError("module with duplicate children is not supported") -def verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device] +def _verify_splitting( + module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] ) -> None: num_parameters = len(list(module.parameters())) num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) @@ -89,66 +89,46 @@ class BalanceError(ValueError): pass -def split_module( - module: nn.Sequential, balance: Iterable[int], devices: List[torch.device], -) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]: - """Splits a module into multiple partitions. +def _retrieve_device(module: nn.Module) -> torch.device: + """Validates all parameters in the Module have the same device and returns + the appropriate device. - Returns: - A tuple of (partitions, balance, devices). + Arguments: + An ``nn.Module`` to process. - Partitions are represented as a :class:`~torch.nn.ModuleList` whose - item is a partition. All layers in a partition are placed in the - same device. + Returns: + ``torch.Device`` for the entire module. Raises: - BalanceError: - wrong balance - IndexError: - the number of devices is fewer than the number of partitions. - + ValueError: + If devices for ``nn.Module`` parameters are not all same. """ - balance = list(balance) - if len(module) != sum(balance): - raise BalanceError( - "module and sum of balance have different length " - f"(module: {len(module)}, sum of balance: {sum(balance)})" - ) + device = None + for parameter in module.parameters(): + if device is None: + device = parameter.device + elif device != parameter.device: + raise ValueError( + 'nn.Module: {}, should have all parameters on a single device,' + ' please use .to() to place the module on a single device'.format(module)) - if any(x <= 0 for x in balance): - raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") + return device if device is not None else torch.device("cpu") - if len(balance) > len(devices): - raise IndexError( - "too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})" - ) - - j = 0 +def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: partitions = [] - layers: NamedModules = OrderedDict() - - for name, layer in module.named_children(): - layers[name] = layer - - if len(layers) == balance[j]: - # Group buffered layers as a partition. - partition = nn.Sequential(layers) - - device = devices[j] - partition.to(device) - - partitions.append(partition) - - # Prepare for the next partition. - layers.clear() - j += 1 + devices = [] + for name, module in modules.named_children(): + devices.append(_retrieve_device(module)) + if isinstance(module, nn.Sequential): + partition = module + else: + partition = nn.Sequential(OrderedDict([(name, module)])) + partitions.append(partition) partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) - del devices[j:] - - return partitions, balance, devices + return partitions, devices MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement") @@ -160,28 +140,27 @@ class Pipe(Module): :: model = nn.Sequential(a, b, c, d) - model = Pipe(model, balance=[1, 1, 1, 1], chunks=8) + model = Pipe(model, chunks=8) output = model(input) - .. _Pipe: https://arxiv.org/abs/1811.06965 + .. _Pipe: https://arxiv.org/abs/2004.09910 Pipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization. - You should determine the balance when defining a :class:`Pipe` module, as - balancing will not be done automatically. The module will be partitioned - into multiple devices according to the given balance. You may rely on - heuristics to find your own optimal configuration. + You should place all the modules on the appropriate devices before passing + them to this API and wrap them into an ``nn.Sequential`` module defining the + desired order of execution. Args: - module (torch.nn.Sequential): - sequential module to be parallelized - balance (ints): - list of number of layers in each partition + module (``torch.nn.Sequential``): + Sequential module to be parallelized using pipelining. Each module + in the sequence has to have all of its parameters on a single + device. Each module in the sequence has to either be an nn.Module + or ``nn.Sequential`` (to combine multiple sequential modules on a single + device) Keyword Args: - devices (iterable of devices): - devices to use (default: all CUDA devices) chunks (int): number of micro-batches (default: ``1``) checkpoint (str): @@ -196,33 +175,12 @@ class Pipe(Module): TypeError: the module is not a :class:`nn.Sequential `. ValueError: - invalid arguments, or wrong balance + invalid arguments IndexError: the number of devices is fewer than the number of partitions. """ - #: The number of layers in each partition. - balance: List[int] = [] - # ^^ - # The default value [] required for Sphinx's autoattribute. - - #: The devices mapped to each partition. - #: - #: ``devices[-1]`` refers to the device of the last partition, which means - #: it is the output device. Probably, you need to use it to transfer the - #: target to calculate the loss without a device mismatch - #: :exc:`RuntimeError`. For example:: - #: - #: out_device = pipe.devices[-1] - #: - #: for input, target in loader: - #: target = target.to(out_device, non_blocking=True) - #: output = pipe(input) - #: loss = F.cross_entropy(output, target) - #: - devices: List[torch.device] = [] - #: The number of micro-batches. chunks: int = 1 @@ -233,9 +191,6 @@ class Pipe(Module): def __init__( self, module: nn.Sequential, - balance: Optional[Iterable[int]] = None, - *, - devices: Optional[Devices] = None, chunks: int = chunks, checkpoint: str = checkpoint, deferred_batch_norm: bool = False, @@ -245,8 +200,6 @@ def __init__( chunks = int(chunks) checkpoint = str(checkpoint) - if balance is None: - raise ValueError(recommend_auto_balance("balance is required")) if chunks <= 0: raise ValueError("number of chunks must be positive integer") if checkpoint not in ["always", "except_last", "never"]: @@ -264,17 +217,8 @@ def __init__( if deferred_batch_norm: module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - if devices is None: - devices = range(torch.cuda.device_count()) - devices = [torch.device(d) for d in devices] - devices = cast(List[torch.device], devices) - - try: - self.partitions, self.balance, self.devices = split_module(module, balance, devices) - except BalanceError as exc: - raise ValueError(recommend_auto_balance(str(exc))) - - verify_splitting(module, self.partitions, self.balance, self.devices) + self.partitions, self.devices = _split_module(module) + _verify_splitting(module, self.partitions, self.devices) self._copy_streams: List[List[AbstractStream]] = [] self._skip_layout = inspect_skip_layout(self.partitions) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/testing/_internal/distributed/pipeline/utils.py b/torch/testing/_internal/distributed/pipeline/utils.py new file mode 100644 index 000000000000..2bf4829b8223 --- /dev/null +++ b/torch/testing/_internal/distributed/pipeline/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +from typing import List + +def convert_to_balance(pipe: nn.Sequential, balance: List[int]): + device_idx = 0 + pipe_idx = 0 + balanced_pipe = [] + for num_layers in balance: + layers = [] + for i in range(num_layers): + layers.append(pipe[pipe_idx]) + pipe_idx += 1 + balanced_pipe.append(nn.Sequential(*layers).to(device_idx)) + device_idx += 1 + + return nn.Sequential(*balanced_pipe) From 21f447ee2c6ebbd72b6c3608c4df17c74edd4784 Mon Sep 17 00:00:00 2001 From: Garret Catron Date: Thu, 12 Nov 2020 14:26:59 -0800 Subject: [PATCH 78/93] Added serialization of parameters for leaf modules (#47729) Summary: This adds the serialization of parameters of leaf nodes to the json serialization. Specifically __constants__ of the leaf module is serialized as parameters in the JSON. It also adds type/shape to leaf modules as well. ``` { "shape": "[3, 3, 1, 1]", "dtype": "torch.float32", "parameters": { "name": "Conv2d", "stride": [ 1, 1 ], "padding": [ 0, 0 ], "dilation": [ 1, 1 ], "groups": 1, "padding_mode": "zeros", "output_padding": [ 0, 0 ], "in_channels": 3, "out_channels": 3, "kernel_size": [ 2, 2 ] }, "target": "conv", "op_code": "call_module", "name": "conv", "args": [ { "is_node": true, "name": "c" } ], "kwargs": {} }, ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/47729 Reviewed By: ailzhang Differential Revision: D24901632 Pulled By: gcatron fbshipit-source-id: 7f2d923937042b60819c58fd180b426a3733ff5f --- test/test_fx_experimental.py | 25 +++++++------ torch/fx/experimental/GraphManipulation.py | 43 +++++++++++++++++----- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 884e21a17ddd..6f07e85211a4 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -32,10 +32,12 @@ def __init__(self): super().__init__() self.linear = torch.nn.Linear(4, 4) self.e = torch.rand(4) + self.conv = torch.nn.Conv2d(3, 3, 2, bias=False) - def forward(self, a, b): + def forward(self, a, b, c): add_1 = a + b - linear = self.linear(add_1) + conv1 = self.conv(c) + linear = self.linear(add_1 + conv1) add_2 = linear + self.e return add_2 @@ -43,7 +45,8 @@ def forward(self, a, b): traced = symbolic_trace(m) a = torch.rand(4) b = torch.rand(4) - GraphManipulation.get_size_of_all_nodes(traced, [a, b]) + c = torch.rand(3, 3, 2, 2) + GraphManipulation.get_size_of_all_nodes(traced, [a, b, c]) partitioner = Partitioner() devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)] @@ -66,13 +69,13 @@ def forward(self, a, b): agm1 = GraphManipulation.AcceleratedGraphModule(traced) agm2 = GraphManipulation.AcceleratedGraphModule(module_with_submodules) - assert len(agm1.weights) == 3 - assert len(agm2.weights) == 3 - assert len(agm1.serialized_graph["nodes"]) == 7 - assert len(agm1.serialized_graph["weights"]) == 3 + assert len(agm1.weights) == 4 + assert len(agm2.weights) == 4 + assert len(agm1.serialized_graph["nodes"]) == 10 + assert len(agm1.serialized_graph["weights"]) == 4 assert len(agm1.serialized_graph["modules"]) == 0 - assert len(agm2.serialized_graph["nodes"]) == 5 - assert len(agm2.serialized_graph["weights"]) == 3 + assert len(agm2.serialized_graph["nodes"]) == 6 + assert len(agm2.serialized_graph["weights"]) == 4 assert len(agm2.serialized_graph["modules"]) == 1 assert agm1.serialized_graph["weights"]["linear.weight"]["shape"] == "[4, 4]" assert ( @@ -87,8 +90,8 @@ def forward(self, a, b): assert agm1.serialized_graph["nodes"][0]["target"] == "a" assert agm1.serialized_graph["nodes"][0]["op_code"] == "placeholder" assert agm1.serialized_graph["nodes"][0]["name"] == "a" - assert agm1.serialized_graph["nodes"][2]["args"][0]["name"] == "a" - assert agm1.serialized_graph["nodes"][2]["args"][0]["is_node"] is True + assert agm1.serialized_graph["nodes"][6]["args"][0]["name"] == "add_2" + assert agm1.serialized_graph["nodes"][6]["args"][0]["is_node"] is True # Test quantization info serialization. x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) diff --git a/torch/fx/experimental/GraphManipulation.py b/torch/fx/experimental/GraphManipulation.py index 7bd303f55d04..2eea162faedb 100644 --- a/torch/fx/experimental/GraphManipulation.py +++ b/torch/fx/experimental/GraphManipulation.py @@ -91,7 +91,7 @@ def serialize_shape(shape: torch.Size) -> str: def serialize_tensor_quantization(tensor: torch.Tensor) -> Dict[str, Any]: - scheme = {} # type: Dict[str, Any] + scheme: Dict[str, Any] = {} if tensor.is_quantized: scheme["q_scheme"] = str(tensor.qscheme()) if tensor.qscheme() in {torch.per_tensor_affine, torch.per_tensor_symmetric}: @@ -112,7 +112,7 @@ def serialize_tensor_quantization(tensor: torch.Tensor) -> Dict[str, Any]: def serialize_weight(tensor: torch.Tensor) -> Dict: - weight = {} # type: Dict[str, Any] + weight: Dict[str, Any] = {} weight["dtype"] = str(tensor.dtype) weight["is_quantized"] = tensor.is_quantized if tensor.is_quantized: @@ -121,6 +121,23 @@ def serialize_weight(tensor: torch.Tensor) -> Dict: return weight +def serialize_leaf_module( + mod: torch.nn.Module, weights_metadata: Dict, weights: Dict, name_prefix: str +) -> Dict: + parameters: Dict[str, Any] = {} + parameters["name"] = type(mod).__name__ + for name, buffer in mod.named_buffers(): + weights_metadata[f"{name_prefix}.{name}"] = serialize_weight(buffer) + weights[f"{name_prefix}.{name}"] = buffer + for name, parameter in mod.named_parameters(): + weights_metadata[f"{name_prefix}.{name}"] = serialize_weight(parameter) + weights[f"{name_prefix}.{name}"] = parameter + if isinstance(mod.__constants__, List): + for constant in mod.__constants__: + parameters[constant] = str(getattr(mod, constant)) + return parameters + + def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict: """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON. It also adds all weights the provided weights dictionary by qualified_name. @@ -158,21 +175,24 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D q_per_channel_axis, int } """ - serialized_dict = {} # type: Dict[str, Any] + serialized_dict: Dict[str, Any] = {} serialized_dict["modules"] = {} serialized_dict["weights"] = {} serialized_dict["nodes"] = [] parameters = fx_module.named_parameters() + prefix = f"{name_prefix}." if name_prefix else "" + submodules = dict(fx_module.named_modules()) for name, p in parameters: if isinstance(p, torch.Tensor): weight = serialize_weight(p) - prefix = f"{name_prefix}." if name_prefix else "" serialized_dict["weights"][prefix + name] = weight weights[prefix + name] = p for node in fx_module.graph.nodes: - node_rep = {} # type: Dict[str, Any] + node_rep: Dict[str, Any] = {} # Get shape/type info, currently not needed for call_module. - if node.op != "call_module": + if node.op != "call_module" or not isinstance( + submodules[node.target], GraphModule + ): shape = getattr(node, "shape", None) if shape: node_rep["shape"] = serialize_shape(shape) @@ -190,12 +210,18 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D # Recurse down into any submodules we are calling. if node.op == "call_module": - submodules = dict(fx_module.named_modules()) if isinstance(submodules[node.target], GraphModule): serialized_module = serialize_module( getattr(fx_module, node.target), weights, node.target ) serialized_dict["modules"][node.target] = serialized_module + else: + node_rep["parameters"] = serialize_leaf_module( + submodules[node.target], + serialized_dict["weights"], + weights, + prefix + node.target, + ) if node.op == "call_function": node_rep["target"] = get_qualified_name(node.target) @@ -205,7 +231,6 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D # Make sure we capture all constants. if node.op == "get_attr": target = getattr(fx_module, node.target) - prefix = f"{name_prefix}." if name_prefix else "" qualname = prefix + node.target if isinstance(target, torch.Tensor) and qualname not in weights: weight = serialize_weight(target) @@ -228,6 +253,6 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D class AcceleratedGraphModule: def __init__(self, fx_module: GraphModule): """Creates the needed data structures to pass to the glow runtime""" - self.weights = {} # type: Dict[str, Any] + self.weights: Dict[str, Any] = {} self.serialized_graph = serialize_module(fx_module, self.weights) self.serialized_graph_json = json.dumps(self.serialized_graph, indent=4) From 9734c042b8748c897fa04ded58757c2b66f3b309 Mon Sep 17 00:00:00 2001 From: James Reed Date: Thu, 12 Nov 2020 15:52:09 -0800 Subject: [PATCH 79/93] [FX] Fix submodule naming for subgraph split (#47869) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47869 Test Plan: Imported from OSS Reviewed By: scottxu0730 Differential Revision: D24925283 Pulled By: jamesr66a fbshipit-source-id: a33bff20667405a3bbfc81e1e640c2649c0db03b --- test/test_fx.py | 38 ------------- test/test_fx_experimental.py | 56 +++++++++++++++++++ .../experimental/subgraph_creation_example.py | 3 +- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index b207f11b1d80..4d3a86205e37 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -11,7 +11,6 @@ from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph from torch.fx.experimental import GraphManipulation from torch.fx.experimental import shape_prop -from torch.fx.experimental.subgraph_creation_example import split_module from torch.fx.immutable_collections import immutable_dict, immutable_list from copy import deepcopy @@ -892,43 +891,6 @@ def test_inf_nan_kwds(self): x = torch.rand(3, 4) self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) - def test_subgraph_creation(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x, y): - z = self.linear(x + self.param).clamp(min=0.0, max=1.0) - w = self.linear(y).clamp(min=0.0, max=1.0) - return z + w - - # symbolically trace model - my_module = MyModule() - my_module_traced = symbolic_trace(my_module) - - # random mod partitioning - partition_counter = 0 - NPARTITIONS = 3 - - def mod_partition(node: Node): - nonlocal partition_counter - partition = partition_counter % NPARTITIONS - partition_counter = (partition_counter + 1) % NPARTITIONS - return partition - - # split module in module with submodules - module_with_submodules = split_module(my_module_traced, my_module, mod_partition) - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - - orig_out = my_module_traced(x, y) - submodules_out = module_with_submodules(x, y) - - self.assertEqual(orig_out, submodules_out) - def test_deepcopy_recursion_depth(self): depth = sys.getrecursionlimit() + 20 diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6f07e85211a4..11a3ea2006fd 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,4 +1,5 @@ import torch +import unittest from typing import Dict from torch.fx.symbolic_trace import symbolic_trace from torch.fx.graph_module import GraphModule @@ -8,6 +9,7 @@ from torch.fx.experimental.rewriter import RewritingTracer from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase +from torch.fx.experimental.subgraph_creation_example import split_module from torch.fx.experimental.partitioner_utils import ( NodeLatency, get_partition_to_latency_mapping, @@ -17,6 +19,13 @@ ) from typing import Union, Callable +try: + from torchvision.models import resnet18 + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: return GraphModule( @@ -488,6 +497,53 @@ def forward(self, a, b): # Confirm that the output is correct self.assertEqual(traced(3, 3), m(3, 3)) + def test_subgraph_creation(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + nonlocal partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module(my_module_traced, my_module, mod_partition) + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + + orig_out = my_module_traced(x, y) + submodules_out = module_with_submodules(x, y) + + self.assertEqual(orig_out, submodules_out) + + @skipIfNoTorchVision + def test_subgraph_trivial_resnet(self): + # Smoke test trivially splitting resnet into 1 partition works + # There was an issue before causing submodule names to be aliased + m = resnet18() + traced = symbolic_trace(m) + a = torch.rand(64, 3, 7, 7) + module_with_submodules = split_module(traced, m, lambda node: 0) + module_with_submodules(a) + def test_traceable_function_with_nonstandard_name(self): def foo(x): return torch.relu(x) diff --git a/torch/fx/experimental/subgraph_creation_example.py b/torch/fx/experimental/subgraph_creation_example.py index 930e8f35426e..526259861ccf 100644 --- a/torch/fx/experimental/subgraph_creation_example.py +++ b/torch/fx/experimental/subgraph_creation_example.py @@ -115,7 +115,8 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona if not hasattr(target_attr, atom): raise RuntimeError(f'Operator target {node.target} not found!') target_attr = getattr(target_attr, atom) - target = target_atoms[-1] + # target = target_atoms[-1] + target = '_'.join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) From 7391edb591ae3c9675063ad0fa8d2c5c62e24be1 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Thu, 12 Nov 2020 15:57:08 -0800 Subject: [PATCH 80/93] [hotfix] fix misleadingly summary BLAS=MKL when there's no BLAS install (#47803) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47803 Reviewed By: samestep Differential Revision: D24907453 Pulled By: walterddr fbshipit-source-id: a3e41041f6aa506b054eb0ffc61f8525ba02cbf1 --- cmake/Dependencies.cmake | 2 ++ cmake/Summary.cmake | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 742c87b09233..69f4ff23467e 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1643,6 +1643,8 @@ if(NOT INTERN_BUILD_MOBILE) if(LAPACK_FOUND) set(USE_LAPACK 1) list(APPEND Caffe2_PRIVATE_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) + else() + set(USE_LAPACK 0) endif() if(NOT USE_CUDA) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 1e02114b2dce..5f5473c2b056 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -9,7 +9,6 @@ function(caffe2_print_configuration_summary) message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}") message(STATUS " C++ compiler id : ${CMAKE_CXX_COMPILER_ID}") message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") - message(STATUS " BLAS : ${BLAS}") message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) @@ -51,6 +50,14 @@ function(caffe2_print_configuration_summary) message(STATUS " INTERN_BUILD_MOBILE : ${INTERN_BUILD_MOBILE}") + message(STATUS " USE_BLAS : ${USE_BLAS}") + if(${USE_BLAS}) + message(STATUS " BLAS : ${BLAS_INFO}") + endif() + message(STATUS " USE_LAPACK : ${USE_LAPACK}") + if(${USE_LAPACK}) + message(STATUS " LAPACK : ${LAPACK_INFO}") + endif() message(STATUS " USE_ASAN : ${USE_ASAN}") message(STATUS " USE_CPP_CODE_COVERAGE : ${USE_CPP_CODE_COVERAGE}") message(STATUS " USE_CUDA : ${USE_CUDA}") From 3649a2c170c45653d2aa1267d48beb867914b039 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 12 Nov 2020 16:14:11 -0800 Subject: [PATCH 81/93] [numpy] `torch.sqrt` : promote integer inputs to float (#47293) Summary: Reference https://github.com/pytorch/pytorch/issues/42515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47293 Reviewed By: malfet Differential Revision: D24855994 Pulled By: mruberry fbshipit-source-id: 1e6752f2eeba6d638dea0bdea0c650cf722718c9 --- aten/src/ATen/native/UnaryOps.cpp | 4 +-- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 4 +-- test/test_torch.py | 3 +- torch/csrc/jit/tensorexpr/kernel.cpp | 5 +-- .../_internal/common_methods_invocations.py | 32 ++++++++++++++++++- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 7d57651005d5..8e01aff472ff 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -378,8 +378,8 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(r Tensor arctanh(const Tensor& self) { return self.atanh(); } Tensor& arctanh_(Tensor& self) { return self.atanh_(); } -Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); } -Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); } +Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); } +Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); } Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); } Tensor square(const Tensor& self) { return at::pow(self, 2); } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 4b1f0c1a6aa3..25dbf5a1a6ef 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -78,7 +78,7 @@ __host__ __device__ static inline c10::complex rsqrt_wrapper(c10::complex } void rsqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float. return rsqrt_wrapper(a); @@ -87,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) { } void sqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sqrt(a); }); diff --git a/test/test_torch.py b/test/test_torch.py index 6bcbd5582dc8..ddde396c011b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -20921,8 +20921,7 @@ def __init__(self, self.dtypes = dtypes self.replace_inf_with_nan = replace_inf_with_nan -torch_op_tests = [_TorchMathTestMeta('sqrt'), - _TorchMathTestMeta('erf', ref_backend='scipy'), +torch_op_tests = [_TorchMathTestMeta('erf', ref_backend='scipy'), _TorchMathTestMeta('erfc', ref_backend='scipy'), _TorchMathTestMeta('exp'), _TorchMathTestMeta('expm1'), diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 24bfedc92841..83ecb69774ae 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1146,8 +1146,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::sqrt: { - return computeOneOperand( - "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); }); + return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { + return sqrt(promoteIntegerToFloat(a)); + }); } break; case aten::rsqrt: { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e62d62eb12bf..8918d3ea8c1f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -436,7 +436,37 @@ def sample_inputs(self, device, dtype, requires_grad=False): ref=np.nan_to_num, dtypes=all_types_and(torch.half, torch.bool), dtypesIfCPU=None, - dtypesIfCUDA=None) + dtypesIfCUDA=None), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 7e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.bfloat16]), + # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype. + SkipInfo('TestGradients', 'test_fn_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_fn_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_gradgrad', + dtypes=[torch.cdouble]),), + promotes_integers_to_float=True, + handles_complex_extremals=False), ] # Common operator groupings From edf751ca2fededecdd9366874c761431c0f61f01 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 12 Nov 2020 17:05:52 -0800 Subject: [PATCH 82/93] Make empty c10-full (#46092) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46092 Make empty c10-full without using hacky-wrapper, i.e. port the kernel to the new style signature. This PR also changes the signature of some helpers called by empty to the new style. ghstack-source-id: 116544203 (Note: this ignores all push blocking failures!) Test Plan: vs prev diff (outdated, before c10::optional fix): https://www.internalfb.com/intern/fblearner/details/224735103/ after c10::optional fix: https://www.internalfb.com/intern/fblearner/details/231391773/ Also, after the c10::optional fix, the instruction counting benchmark shows a 2% regression for calling empty from Python. We decided this is acceptable and decided against landing D24425836 which would fix the regression. Reviewed By: ezyang Differential Revision: D24219944 fbshipit-source-id: e554096e90ce438c75b679131c3151ff8e5c5d50 --- aten/src/ATen/BatchingRegistrations.cpp | 7 +++- aten/src/ATen/ScalarOps.cpp | 5 ++- aten/src/ATen/ScalarOps.h | 10 +++-- aten/src/ATen/Utils.cpp | 25 +++++-------- aten/src/ATen/Utils.h | 6 +-- aten/src/ATen/native/MetaTensor.cpp | 23 +++++------- aten/src/ATen/native/TensorFactories.cpp | 37 ++++++++++++------- aten/src/ATen/native/TensorFactories.h | 8 ++-- aten/src/ATen/native/cuda/Indexing.cu | 5 ++- .../src/ATen/native/cuda/MultinomialKernel.cu | 4 +- aten/src/ATen/native/cuda/ScanKernels.cu | 3 +- aten/src/ATen/native/cuda/TensorFactories.cu | 37 +++++++++---------- aten/src/ATen/native/metal/MetalAten.mm | 13 ++++--- aten/src/ATen/native/mkldnn/BinaryOps.cpp | 7 +++- aten/src/ATen/native/mkldnn/Conv.cpp | 15 +++++--- aten/src/ATen/native/mkldnn/Linear.cpp | 6 ++- aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp | 6 ++- aten/src/ATen/native/mkldnn/MKLDNNCommon.h | 2 +- .../ATen/native/mkldnn/MKLDNNConversions.cpp | 9 +++-- aten/src/ATen/native/mkldnn/Normalization.cpp | 18 ++++++--- aten/src/ATen/native/mkldnn/Pooling.cpp | 2 +- aten/src/ATen/native/mkldnn/Relu.cpp | 3 +- aten/src/ATen/native/mkldnn/SoftMax.cpp | 3 +- .../ATen/native/mkldnn/TensorFactories.cpp | 9 ++--- aten/src/ATen/native/mkldnn/TensorShape.cpp | 9 +++-- aten/src/ATen/native/mkldnn/UnaryOps.cpp | 3 +- aten/src/ATen/native/native_functions.yaml | 16 ++++---- aten/src/ATen/native/sparse/SparseTensor.cpp | 34 ++++++++++------- aten/src/ATen/native/vulkan/VulkanAten.cpp | 15 +++++--- aten/src/ATen/test/basic.cpp | 10 +---- aten/src/ATen/test/extension_backend_test.cpp | 11 +++--- c10/core/TensorOptions.h | 14 ++++--- c10/util/Optional.h | 18 +++++++++ test/cpp_extensions/msnpu_extension.cpp | 5 ++- tools/autograd/gen_variable_type.py | 1 + torch/csrc/jit/codegen/cuda/executor.cpp | 25 +++++++------ 36 files changed, 243 insertions(+), 181 deletions(-) diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index c30ddb631d0a..1efc911aeb92 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -869,10 +869,13 @@ Tensor new_zeros_batching_rule( Tensor new_empty_batching_rule( const Tensor& self, IntArrayRef size, - const TensorOptions& options) { + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); auto physical_size = physical_view.getPhysicalShape(size); - auto result = physical_view.tensor().new_empty(physical_size, options); + auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)); return physical_view.newLogicalFromPhysical(result); } diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 7a794cb5c312..26efc74e8c2b 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -29,10 +29,11 @@ Tensor& scalar_fill(Tensor& self, Scalar value) { return self; } -Tensor scalar_tensor_static(Scalar s, const TensorOptions& options) { +Tensor scalar_tensor_static(Scalar s, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { at::tracer::impl::NoTracerDispatchMode tracer_guard; at::AutoNonVariableTypeMode non_var_type_mode(true); - auto result = at::detail::empty_cpu({}, options); + auto result = at::detail::empty_cpu({}, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); scalar_fill(result, s); return result; } diff --git a/aten/src/ATen/ScalarOps.h b/aten/src/ATen/ScalarOps.h index 60cee3ea284b..b081183414a5 100644 --- a/aten/src/ATen/ScalarOps.h +++ b/aten/src/ATen/ScalarOps.h @@ -12,7 +12,9 @@ namespace detail { // but we also want to skip compute_types which in not avoidable // in TensorIterator for now. Tensor& scalar_fill(Tensor& self, Scalar value); -TORCH_API Tensor scalar_tensor_static(Scalar s, const TensorOptions& options); +TORCH_API Tensor scalar_tensor_static(Scalar s, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, + c10::optional memory_format_opt); } // namespace detail } // namespace at @@ -25,12 +27,12 @@ inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { // This is the fast track we have for CPU scalar tensors. if (device == at::kCPU && !s.isComplex()) { if (s.isFloatingPoint()) { - return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kDouble)); + return at::detail::scalar_tensor_static(s, at::kDouble, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt); } else if (s.isBoolean()) { - return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kBool)); + return at::detail::scalar_tensor_static(s, at::kBool, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt); } else { AT_ASSERT(s.isIntegral(false)); - return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kLong)); + return at::detail::scalar_tensor_static(s, at::kLong, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt); } } if (s.isFloatingPoint()) { diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index 8a4fa37e469e..9e1c33a4dbb9 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -16,32 +16,24 @@ int _crash_if_asan(int arg) { namespace detail { // empty_cpu is used in ScalarOps.h, which can be referenced by other ATen files. Since we want to decouple direct referencing native symbols and only access native symbols through dispatching, we move its implementation here. -Tensor empty_cpu( - IntArrayRef size, - const TensorOptions& options, - c10::optional optional_memory_format) { - TORCH_CHECK( - !(options.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - const MemoryFormat memory_format = - optional_memory_format.value_or( - options.memory_format_opt().value_or( - MemoryFormat::Contiguous)); +Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + Device device = device_or_default(device_opt); - AT_ASSERT(options.device().type() == DeviceType::CPU); + TORCH_CHECK(device.type() == DeviceType::CPU); check_size_nonnegative(size); + bool pin_memory = pinned_memory_or_default(pin_memory_opt); c10::Allocator* allocator; - if (options.pinned_memory()) { + if (pin_memory) { allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); } else { allocator = at::getCPUAllocator(); } int64_t nelements = prod_intlist(size); - const caffe2::TypeMeta dtype = options.dtype(); - const int64_t size_bytes = nelements * dtype.itemsize(); + caffe2::TypeMeta dtype = scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); + int64_t size_bytes = nelements * dtype.itemsize(); auto storage_impl = c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size_bytes, @@ -56,6 +48,7 @@ Tensor empty_cpu( tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } + auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); return tensor; diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 4fe4b632362b..14e8fa49c1b1 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -136,10 +136,8 @@ inline void check_size_nonnegative(IntArrayRef size) { namespace detail { CAFFE2_API -Tensor empty_cpu( - IntArrayRef size, - const TensorOptions& options = {}, - c10::optional memory_format = c10::nullopt); +Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt); } // namespace detail } // at diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index a7042b283c4c..af293d7ebe21 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -7,34 +7,29 @@ namespace native { // Will be promoted to a public API later, but not now Tensor empty_meta( IntArrayRef size, - const TensorOptions& options_, - c10::optional optional_memory_format + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format ) { - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = options_.merge_memory_format(optional_memory_format); - // TODO: deduplicate this logic with empty_cpu - auto dtype = options.dtype(); - auto device = options.device(); auto tensor = detail::make_tensor( // NB: We include the computed dispatch key, not because it will actually // participate in dispatch, but so that tests like is_sparse/is_cuda // give the correct result (a CUDA meta tensor "is cuda"). If we don't // like this, remove the computeDispatchKey line - DispatchKeySet{DispatchKey::Meta, options.computeDispatchKey()}, - dtype, + DispatchKeySet{DispatchKey::Meta, computeDispatchKey(dtype, layout, device)}, + scalarTypeToTypeMeta(dtype_or_default(dtype)), device ); if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous); + tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_); return tensor; } diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1e5bf2ebc6fa..42d98336e5cd 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -165,8 +165,9 @@ Tensor polar(const Tensor& abs, const Tensor& angle) { } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional optional_memory_format) { - return at::detail::empty_cpu(size, options_, optional_memory_format); +Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } Tensor empty( @@ -186,9 +187,10 @@ Tensor empty( return result; } -Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) { +Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { check_size_nonnegative(size); - auto t = at::native::empty_cpu({0}, options); + auto t = at::native::empty_cpu({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt); at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride); return t; } @@ -336,9 +338,16 @@ Tensor empty_like( Tensor new_empty( const Tensor& self, IntArrayRef size, - const TensorOptions& options + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt ) { - return at::empty(size, self.options().merge_in(options)); + auto dtype = dtype_opt.has_value() ? dtype_opt : optTypeMetaToScalarType(self.options().dtype_opt()); + auto layout = layout_opt.has_value() ? layout_opt : self.options().layout_opt(); + auto device = device_opt.has_value() ? device_opt : self.options().device_opt(); + auto pin_memory = pin_memory_opt.has_value() ? pin_memory_opt : self.options().pinned_memory_opt(); + return at::empty(size, dtype, layout, device, pin_memory, c10::nullopt); } Tensor new_empty_strided( @@ -507,7 +516,7 @@ Tensor scalar_tensor(Scalar s, const TensorOptions& options) { // auto result = at::empty({}, options); at::tracer::impl::NoTracerDispatchMode tracer_guard; at::AutoNonVariableTypeMode non_var_type_mode(true); - auto result = empty_cpu({}, options); + auto result = empty_cpu({}, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); at::native::fill_(result, s); return result; } @@ -735,13 +744,14 @@ Tensor range( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor tril_indices_cpu( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto tril_size = get_tril_size(row, col, offset); // create an empty Tensor with correct size - auto result = at::empty({2, tril_size}, options); + auto result = at::native::empty_cpu({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); // The following three approaches result in very little performance // differences. Hence, the 2nd option is taken for simpler code, and to return @@ -780,13 +790,14 @@ Tensor tril_indices_cpu( } Tensor triu_indices_cpu( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto triu_size = row * col - get_tril_size(row, col, offset - 1); // create an empty Tensor with correct size - auto result = at::empty({2, triu_size}, options); + auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void { // fill the Tensor with correct values diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 8cae202efe13..579cfdb624e7 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -50,14 +50,14 @@ inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) { } inline void check_args( - int64_t row, int64_t col, const TensorOptions& options) { + int64_t row, int64_t col, c10::optional layout_opt) { TORCH_CHECK(row >= 0, "row must be non-negative, got", row); TORCH_CHECK(col >= 0, "col must be non-negative, got", col); - if (options.has_layout()) { + if (layout_opt.has_value()) { TORCH_CHECK( - options.layout() == at::kStrided, + *layout_opt == at::kStrided, "only support layout=torch.strided, got", - options.layout()) + *layout_opt) } } diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index d3c0b8d29553..47527935fe73 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -882,7 +882,8 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); at::Tensor out_temp = need_to_copy ? - at::native::empty_cuda({self.dim(), num_nonzeros_h}, out.options()) : + at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()), + out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) : out.resize_({self.dim(), num_nonzeros_h}); //Scalars are expected to produce output of size (1,0), so we can't write to it if (self.dim() > 0) { @@ -931,7 +932,7 @@ Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self){ } Tensor nonzero_cuda(const Tensor& self){ - Tensor out = at::native::empty_cuda({0}, self.options().dtype(kLong)); + Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt()); return nonzero_out_cuda(out, self); } diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 4e96c2868336..b828e47e8461 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -322,7 +322,9 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n // To exploit greater parallelism for the sampling, generate the // Uniform random samples in a separate kernel launch, into // temporarily allocated memory. The device RNG is thread-limited - Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options()); + Tensor sampled = native::empty_cuda({numDist, n_sample}, optTypeMetaToScalarType(self_v.options().dtype_opt()), + self_v.options().layout_opt(), self_v.options().device_opt(), + self_v.options().pinned_memory_opt()); at::native::uniform_(sampled, 0.0, 1.0, generator); dim3 block(numCategories < maxThreads ? numCategories : maxThreads); diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index b0dc71c568ba..099512912203 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -497,7 +497,8 @@ void scan_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction at::cuda::getCurrentCUDAStream())); auto temp_storage = at::native::empty_cuda( {static_cast(temp_storage_bytes)}, - self.options().dtype(kByte)); + kByte, self.options().layout_opt(), self.options().device_opt(), + self.options().pinned_memory_opt()); AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan( temp_storage.data_ptr(), temp_storage_bytes, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 74f6e91fb590..a241f7df533c 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -41,15 +41,16 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { return result; } -Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { - AT_ASSERT(options.device().type() == at::DeviceType::CUDA); - TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); +Tensor empty_cuda(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + AT_ASSERT(device_or_default(device_opt).type() == at::DeviceType::CUDA); + TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); auto* allocator = at::cuda::getCUDADeviceAllocator(); int64_t nelements = prod_intlist(size); - auto dtype = options.dtype(); - int64_t size_bytes = nelements * dtype.itemsize(); + auto dtype = dtype_or_default(dtype_opt); + auto dtype_meta = scalarTypeToTypeMeta(dtype); + int64_t size_bytes = nelements * dtype_meta.itemsize(); auto storage_impl = c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size_bytes, @@ -58,23 +59,19 @@ Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional< /*resizeable=*/true); auto tensor = - detail::make_tensor(storage_impl, DispatchKey::CUDA, dtype); + detail::make_tensor(storage_impl, DispatchKey::CUDA, dtype_meta); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - TORCH_CHECK( - !(options.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - auto memory_format = options.memory_format_opt().value_or(optional_memory_format.value_or(MemoryFormat::Contiguous)); + auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); return tensor; } -Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) { - auto t = at::native::empty_cuda({0}, options); +Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + auto t = at::native::empty_cuda({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt); at::native::resize_impl_cuda_(t.unsafeGetTensorImpl(), size, stride); return t; } @@ -325,11 +322,12 @@ void tril_indices_kernel(scalar_t * tensor, // implementation, please enable them in test/test_cuda.py and make sure they // pass on your local server. Tensor tril_indices_cuda( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto tril_size = get_tril_size(row, col, offset); - auto tensor = empty_cuda({2, tril_size}, options); + auto tensor = empty_cuda({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); if (tril_size > 0) { auto m_first_row = offset > 0 ? @@ -399,11 +397,12 @@ void triu_indices_kernel(scalar_t * tensor, // implementation, please enable them in test/test_cuda.py and make sure they // pass on your local server. Tensor triu_indices_cuda( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto triu_size = row * col - get_tril_size(row, col, offset - 1); - auto tensor = empty_cuda({2, triu_size}, options); + auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); if (triu_size > 0) { // # of triu elements in the first row diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm index d9550352b922..77fec7dae1ff 100644 --- a/aten/src/ATen/native/metal/MetalAten.mm +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -70,17 +70,20 @@ Tensor empty( IntArrayRef size, - const TensorOptions& options, + optional dtype, + optional layout, + optional device, + optional pin_memory, c10::optional memory_format) { TORCH_CHECK( - !options.has_pinned_memory(), + !pin_memory.has_value(), "'pin_memory' argument is incompatible with Metal tensor"); TORCH_CHECK( - !options.has_memory_format() && !memory_format, + !memory_format.has_value(), "'memory_format' argument is incompatible with Metal tensor"); MetalTensor mt{size.vec()}; return MetalTensor::toTensor( - std::move(mt), at::device(at::kMetal).dtype(options.dtype())); + std::move(mt), at::device(at::kMetal).dtype(dtype)); }; at::Tensor empty_strided( @@ -249,7 +252,7 @@ Tensor flatten_using_ints( m.impl("add.Tensor", TORCH_FN(add_Tensor)); m.impl("add_.Tensor", TORCH_FN(add__Tensor)); m.impl("addmm", TORCH_FN(addmm)); - m.impl_UNBOXED("empty.memory_format", empty); + m.impl("empty.memory_format", empty); m.impl("empty_strided", TORCH_FN(empty_strided)); m.impl("log_softmax.int", TORCH_FN(log_softmax_int)); m.impl("max_pool2d", TORCH_FN(max_pool2d)); diff --git a/aten/src/ATen/native/mkldnn/BinaryOps.cpp b/aten/src/ATen/native/mkldnn/BinaryOps.cpp index 3364fe8b335c..029b1d225d14 100644 --- a/aten/src/ATen/native/mkldnn/BinaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/BinaryOps.cpp @@ -68,7 +68,8 @@ Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) { const std::vector scales{1.0, alpha.to()}; ideep::sum::compute(scales, {x, y}, z); - return new_with_itensor_mkldnn(std::move(z), self.options()); + return new_with_itensor_mkldnn(std::move(z), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) { @@ -99,7 +100,9 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) } Tensor mkldnn_mul(const Tensor& self, const Tensor& other) { - Tensor result = empty_mkldnn(self.sizes(), self.options()); + Tensor result = empty_mkldnn(self.sizes(), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().layout_opt(), self.options().device_opt(), + self.options().pinned_memory_opt()); return native::mkldnn_mul_out(result, self, other); } diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 664f7bbd8f1e..8ee584b1bfce 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -123,10 +123,12 @@ Tensor mkldnn_convolution( groups); if (input.is_mkldnn()) { - return new_with_itensor_mkldnn(std::move(mkldnn_output), input.options()); + return new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()); } else { return mkldnn_to_dense( - new_with_itensor_mkldnn(std::move(mkldnn_output), input.options())); + new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt())); } } @@ -150,7 +152,8 @@ Tensor mkldnn_convolution_backward_input( groups); return mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_input), - grad_output.options())); + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt())); } std::tuple mkldnn_convolution_backward_weights( @@ -188,9 +191,11 @@ std::tuple mkldnn_convolution_backward_weights( return std::make_tuple( mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), - grad_output.options())), + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt())), mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), - grad_output.options()))); + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt()))); } std::tuple mkldnn_convolution_backward( diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index bcc9b786b869..21d240ef5279 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -54,9 +54,11 @@ Tensor mkldnn_linear( output_size.push_back(weight.size(0)); if (self.dim() > 2) { - return new_with_itensor_mkldnn(std::move(y), self.options()).reshape(output_size); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()).reshape(output_size); } - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index a58174034967..b343f5bd77ee 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -40,14 +40,16 @@ using IDeepTensorWrapperPtr = c10::intrusive_ptr; using MKLDNNTensorImpl = OpaqueTensorImpl; using MKLDNNTensor = Tensor; -Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) { +Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional dtype, c10::optional device) { // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t // TODO: support int64_t dims in ideep::tensor to avoid extra conversion auto dims = it.get_dims(); IDeepTensorWrapperPtr handle = c10::make_intrusive(std::move(it)); + caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype)); + Device device_ = device_or_default(device); return detail::make_tensor( DispatchKeySet(DispatchKey::MkldnnCPU), - options.dtype(), options.device(), handle, + dtype_, device_, handle, std::vector(dims.begin(), dims.end())); } diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index 0167b8183d46..86d67ac823a3 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -9,7 +9,7 @@ namespace at { namespace native { // Construct aten MKL-DNN tensor given an ideep tensor -Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options); +Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional dtype, c10::optional device); // Retrieve `ideep::tensor` from MKL-DNN tensor ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp index 971fa7a3af2f..743503d3264c 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp @@ -32,7 +32,9 @@ Tensor dense_to_mkldnn(const Tensor& cpu_tensor) { "Can't convert cpu tensor with the number of dimensions > 5"); // TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly. auto cpu_tensor_cont = cpu_tensor.contiguous(); - Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), cpu_tensor_cont.options()); + Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), optTypeMetaToScalarType(cpu_tensor_cont.options().dtype_opt()), + cpu_tensor_cont.options().layout_opt(), cpu_tensor_cont.options().device_opt(), + cpu_tensor_cont.options().pinned_memory_opt()); ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor); dtensor.feed_from(dtensor.get_dims(), ideep::tensor::data_type::f32, @@ -79,7 +81,8 @@ Tensor mkldnn_reorder_conv2d_weight( result.init(desc); result.feed_from(w); - return new_with_itensor_mkldnn(std::move(result), self.options()); + return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_reorder_conv3d_weight( @@ -105,7 +108,7 @@ Tensor mkldnn_reorder_conv3d_weight( result.init(desc); result.feed_from(w); - return new_with_itensor_mkldnn(std::move(result), self.options()); + return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt()); } #else diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 86d9d0643a27..ca331392acd8 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -56,18 +56,24 @@ std::tuple mkldnn_batch_norm( // ideep::batch_normalization_forward_training::compute( // x, w, b, y, saved_mean, saved_var, m, v, momentum, eps); // return std::make_tuple( - // new_with_itensor_mkldnn(std::move(y), input.options()), - // new_with_itensor_mkldnn(std::move(saved_mean), input.options()), - // new_with_itensor_mkldnn(std::move(saved_var), input.options())); + // new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt()), + // new_with_itensor_mkldnn(std::move(saved_mean), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt()), + // new_with_itensor_mkldnn(std::move(saved_var), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt())); } else { TORCH_CHECK(input.dim() == 4 || input.dim() == 5, "mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm"); ideep::batch_normalization_forward_inference::compute( x, m, v, w, b, y, eps); return std::make_tuple( - new_with_itensor_mkldnn(std::move(y), input.options()), - new_with_itensor_mkldnn(ideep::tensor{}, input.options()), - new_with_itensor_mkldnn(ideep::tensor{}, input.options())); + new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()), + new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()), + new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt())); } } diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp index a272bc3d6070..5f744f494443 100644 --- a/aten/src/ATen/native/mkldnn/Pooling.cpp +++ b/aten/src/ATen/native/mkldnn/Pooling.cpp @@ -174,7 +174,7 @@ static Tensor _mkldnn_pooling( algo, ideep::prop_kind::forward); - return new_with_itensor_mkldnn(std::move(y), input.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), input.options().device_opt()); } Tensor mkldnn_max_pool2d( diff --git a/aten/src/ATen/native/mkldnn/Relu.cpp b/aten/src/ATen/native/mkldnn/Relu.cpp index 42397255caf0..6915447980bb 100644 --- a/aten/src/ATen/native/mkldnn/Relu.cpp +++ b/aten/src/ATen/native/mkldnn/Relu.cpp @@ -28,7 +28,8 @@ Tensor mkldnn_relu(const Tensor& input) { ideep::tensor y; ideep::eltwise_forward::compute( x, y, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0); - return new_with_itensor_mkldnn(std::move(y), input.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()); } Tensor& mkldnn_relu_(Tensor& input) { diff --git a/aten/src/ATen/native/mkldnn/SoftMax.cpp b/aten/src/ATen/native/mkldnn/SoftMax.cpp index cdeb6cb85971..861cca0aae53 100644 --- a/aten/src/ATen/native/mkldnn/SoftMax.cpp +++ b/aten/src/ATen/native/mkldnn/SoftMax.cpp @@ -35,7 +35,8 @@ Tensor mkldnn_softmax( ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y; ideep::softmax_forward::compute(x, y, wrapped_dim); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp index 603819ed3287..75d20b8bcaba 100644 --- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp +++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp @@ -4,10 +4,7 @@ namespace at { namespace native { #if AT_MKLDNN_ENABLED() -Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional optional_memory_format) { - TORCH_CHECK( - !options.has_memory_format(), - "'memory_format' argument is incompatible with mkldnn tensor"); +Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { TORCH_CHECK( !optional_memory_format.has_value(), "'memory_format' argument is incompatible with mkldnn tensor"); @@ -15,12 +12,12 @@ Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::option // TODO: support int64_t dims in ideep::tensor to avoid extra conversion ideep::tensor::dims dst_dims (sizes.begin(), sizes.end()); ideep::tensor it {dst_dims, ideep::tensor::data_type::f32}; - return new_with_itensor_mkldnn(std::move(it), options); + return new_with_itensor_mkldnn(std::move(it), dtype, device); } #else -Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional optional_memory_format) { +Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled"); } diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index 3229a07e9460..6e31a3a8aa93 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -51,7 +51,8 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) { const ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y{x}; y.reshape(inferred_size); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_clone(const Tensor& self, c10::optional optional_memory_format) { @@ -62,7 +63,8 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional optiona ideep::tensor& src = itensor_from_mkldnn(self); ideep::tensor dst; ideep::direct_copy::compute(src, dst); - return new_with_itensor_mkldnn(std::move(dst), self.options()); + return new_with_itensor_mkldnn(std::move(dst), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { @@ -72,7 +74,8 @@ Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { std::iota(axes.begin(), axes.end(), 0); std::swap(axes[dim0], axes[dim1]); y.transpose_from(x, axes); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) { diff --git a/aten/src/ATen/native/mkldnn/UnaryOps.cpp b/aten/src/ATen/native/mkldnn/UnaryOps.cpp index 4eb02dc483c5..1434512b5241 100644 --- a/aten/src/ATen/native/mkldnn/UnaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/UnaryOps.cpp @@ -30,7 +30,8 @@ Tensor mkldnn_sigmoid(const Tensor& self) { ideep::tensor y; ideep::eltwise_forward::compute( x, y, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_sigmoid_(Tensor& self) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2dcd3d234e46..08cae18d4ae0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1612,13 +1612,13 @@ CUDA: _embedding_bag_per_sample_weights_backward_cuda - func: empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - #use_c10_dispatcher: full + use_c10_dispatcher: full - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor device_guard: False - func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - #use_c10_dispatcher: full + use_c10_dispatcher: full dispatch: CPU: empty_cpu CUDA: empty_cuda @@ -1626,7 +1626,7 @@ SparseCPU, SparseCUDA: empty_sparse - func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - #use_c10_dispatcher: full + use_c10_dispatcher: full variants: method - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -1679,7 +1679,7 @@ device_guard: False - func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + use_c10_dispatcher: full dispatch: CPU: empty_strided_cpu CUDA: empty_strided_cuda @@ -4595,12 +4595,12 @@ use_c10_dispatcher: full - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: new_with_dims_sparse - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse @@ -5680,13 +5680,13 @@ DefaultBackend: tril - func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + use_c10_dispatcher: full dispatch: CPU: tril_indices_cpu CUDA: tril_indices_cuda - func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + use_c10_dispatcher: full dispatch: CPU: triu_indices_cpu CUDA: triu_indices_cuda diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 7f58f1631bc4..1eac8efa7a24 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -70,22 +70,23 @@ Tensor values_sparse(const Tensor& self) { /*** Helper methods ***/ -SparseTensor new_sparse(const TensorOptions& options) { - AT_ASSERT(options.layout() == kSparse); +SparseTensor new_sparse(c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { + AT_ASSERT(layout.has_value() && *layout == kSparse); DispatchKey dispatch_key; - if (options.device().is_cuda()) { + if (device_or_default(device).is_cuda()) { dispatch_key = DispatchKey::SparseCUDA; } else { dispatch_key = DispatchKey::SparseCPU; } return detail::make_tensor( - DispatchKeySet(dispatch_key), options.dtype()); + DispatchKeySet(dispatch_key), scalarTypeToTypeMeta(dtype_or_default(dtype))); } /** Actual dispatched creation methods ***/ -SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef size, const TensorOptions& options) { - SparseTensor self = new_sparse(options); +SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, c10::optional pin_memory) { + SparseTensor self = new_sparse(dtype, layout, device, pin_memory); get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size); return self; } @@ -96,8 +97,11 @@ SparseTensor new_with_dims_and_tensor_sparse( ArrayRef size, const LongTensor& indices, const Tensor& values, - const TensorOptions& options) { - SparseTensor self = new_sparse(options); + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + SparseTensor self = new_sparse(dtype, layout, device, pin_memory); get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size); // NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However, // we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't @@ -115,9 +119,9 @@ SparseTensor new_with_dims_and_tensor_sparse( /** Public creation API that dispatch to methods above **/ /** Empty init **/ -Tensor empty_sparse(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { - TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); - return new_with_dims_sparse(size.size(), 0, size, options); +Tensor empty_sparse(IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { + TORCH_CHECK(!pin_memory.has_value() || !*pin_memory, "Only dense CPU tensors can be pinned"); + return new_with_dims_sparse(size.size(), 0, size, dtype, layout, device, pin_memory); } /* Shape init */ @@ -260,7 +264,9 @@ SparseTensor clone_sparse(const SparseTensor& self, c10::optionalresize_(sparse_dim, dense_dim, self.sizes()); // TODO: is there a more idiomatic way to do this? LongTensor newIndices = at::empty(indices.sizes(), indices.options()); diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp index 18df3ae818f3..4dba9de7d5b0 100644 --- a/aten/src/ATen/native/vulkan/VulkanAten.cpp +++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp @@ -55,17 +55,20 @@ VulkanTensor& vtensor_from_vulkan(Tensor& tensor) { Tensor empty( IntArrayRef size, - const TensorOptions& options, + optional dtype, + optional layout, + optional device, + optional pin_memory, const optional memory_format) { TORCH_CHECK( - !options.pinned_memory(), + !pin_memory.has_value(), "'pin_memory' argument is incompatible with Vulkan tensor"); TORCH_CHECK( - !options.has_memory_format() && !memory_format, + !memory_format.has_value(), "'memory_format' argument is incompatible with Vulkan tensor"); VulkanTensor vt{size.vec()}; return new_with_vtensor_vulkan( - std::move(vt), at::device(at::kVulkan).dtype(options.dtype())); + std::move(vt), at::device(at::kVulkan).dtype(dtype)); } Tensor empty_strided( @@ -76,7 +79,7 @@ Tensor empty_strided( optional device, optional pin_memory) { return vulkan::aten::empty( - size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), c10::nullopt); + size, dtype, layout, device, pin_memory, c10::nullopt); } Tensor upsample_nearest2d( @@ -548,7 +551,7 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl_UNBOXED("transpose_", at::native::vulkan::aten::transpose_); m.impl("view", TORCH_FN(at::native::vulkan::aten::view)); m.impl("unsqueeze", TORCH_FN(at::native::vulkan::aten::unsqueeze)); - m.impl_UNBOXED("empty.memory_format", at::native::vulkan::aten::empty); + m.impl("empty.memory_format", at::native::vulkan::aten::empty); m.impl("empty_strided", TORCH_FN(at::native::vulkan::aten::empty_strided)); m.impl("add.Tensor", TORCH_FN(at::native::vulkan::aten::add)); m.impl("clamp", TORCH_FN(at::native::vulkan::aten::clamp)); diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 144c4671e50f..a81a9a06cea6 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -370,14 +370,8 @@ TEST(BasicTest, FactoryMethodsTest) { ASSERT_FALSE(tensor0.is_pinned()); // Test setting requires_grad to true. - tensor0 = at::empty({4}, at::TensorOptions().requires_grad(true)); - ASSERT_EQ(tensor0.dtype(), at::kFloat); - ASSERT_EQ(tensor0.layout(), at::kStrided); - ASSERT_EQ(tensor0.device(), at::kCPU); - // This is a bug. Requires_grad was set to TRUE but this is being ignored. - // Issue https://github.com/pytorch/pytorch/issues/30405 - ASSERT_FALSE(tensor0.requires_grad()); - ASSERT_FALSE(tensor0.is_pinned()); + // This is a bug. Requires_grad was set to TRUE but this is not implemented. + EXPECT_ANY_THROW(at::empty({4}, at::TensorOptions().requires_grad(true))); // Test setting dtype at::Tensor tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf)); diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index ad7b7b2074e7..f0ec67a49ac0 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -10,7 +10,8 @@ using namespace at; static int test_int; -Tensor empty_override(IntArrayRef size, const TensorOptions & options, c10::optional optional_memory_format) { +Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 1; auto tensor_impl = c10::make_intrusive( Storage( @@ -37,13 +38,13 @@ Tensor empty_strided_override( c10::optional device, c10::optional pin_memory) { - return empty_override(size, at::kMSNPU, c10::nullopt); + return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt); } TORCH_LIBRARY_IMPL(aten, MSNPU, m) { - m.impl_UNBOXED("aten::empty.memory_format", empty_override); - m.impl_UNBOXED("aten::empty_strided", empty_strided_override); - m.impl_UNBOXED("aten::add.Tensor", add_override); + m.impl("aten::empty.memory_format", empty_override); + m.impl("aten::empty_strided", empty_strided_override); + m.impl("aten::add.Tensor", add_override); } TEST(BackendExtensionTest, TestRegisterOp) { diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index c8c7f058513d..420e0a984b77 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -22,23 +22,23 @@ namespace c10 { DispatchKey computeDispatchKey(c10::optional dtype, c10::optional layout, c10::optional device); inline ScalarType dtype_or_default(c10::optional dtype) { - return dtype.has_value() ? *dtype : get_default_dtype_as_scalartype(); + return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();}); } inline caffe2::TypeMeta dtype_or_default(c10::optional dtype) { - return dtype.has_value() ? *dtype : get_default_dtype(); + return value_or_else(dtype, [] {return get_default_dtype();}); } inline Layout layout_or_default(c10::optional layout) { - return layout.has_value() ? *layout : kStrided; + return layout.value_or(kStrided); } inline Device device_or_default(c10::optional device) { - return device.has_value() ? *device : Device(kCPU); + return value_or_else(device, [] {return Device(kCPU);}); } inline bool pinned_memory_or_default(c10::optional pinned_memory) { - return pinned_memory.has_value() ? *pinned_memory : false; + return pinned_memory.value_or(false); } /// A class to encapsulate construction axes of an Tensor. TensorOptions was @@ -121,6 +121,8 @@ inline bool pinned_memory_or_default(c10::optional pinned_memory) { /// To get around this, we templatize the `Device` constructor. Since overload /// resolution is done before template resolution, our problem is solved. +DispatchKey computeDispatchKey(optional dtype, optional layout, optional device); + struct C10_API TensorOptions { TensorOptions() @@ -402,7 +404,7 @@ struct C10_API TensorOptions { return DispatchKeySet(computeDispatchKey()); } - inline DispatchKey computeDispatchKey() const { + DispatchKey computeDispatchKey() const { return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); } diff --git a/c10/util/Optional.h b/c10/util/Optional.h index 62a9f3f513d0..4d7d4b507771 100644 --- a/c10/util/Optional.h +++ b/c10/util/Optional.h @@ -36,6 +36,8 @@ #include #include +#include + #define TR2_OPTIONAL_REQUIRES(...) \ typename std::enable_if<__VA_ARGS__::value, bool>::type = false @@ -643,6 +645,22 @@ class optional : private OptionalBase { } }; +template +constexpr T value_or_else(const optional& v, F&& func) { + static_assert(std::is_convertible::return_type, T>::value, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? *v : detail_::convert(std::forward(func)()); +} + +template +constexpr T value_or_else(optional&& v, F&& func) { + static_assert(std::is_convertible::return_type, T>::value, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() + ? constexpr_move(std::move(v).contained_val()) + : detail_::convert(std::forward(func)()); +} + // XXX: please refrain from using optional, since it is being against with // the optional standard in c++ 17, see the debate and the details here: diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp index 62f046e0037c..88c1d509b34c 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/msnpu_extension.cpp @@ -20,9 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { return Tensor(std::move(tensor_impl)); } -Tensor empty_override(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { +Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, + c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 0; - return get_tensor(options.dtype(), size); + return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); } Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index cb2df20492cd..ccce10e820ce 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -399,6 +399,7 @@ def gen_variable_type_shard(out, aten_declarations, template_path, suffix, heade strategy = dispatch_strategy(declaration) if declaration['name'] not in MANUAL_AUTOGRAD and strategy == 'use_derived': body = emit_body(declaration) + type_definitions.append(METHOD_DEFINITION.substitute( declaration, type_definition_body=body, formals=formals)) if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']: diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index f9733201ec1c..76ba1faf6641 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -175,17 +175,18 @@ at::Tensor inferAndAlloc( } auto at_type = data_type_to_aten(tv->getDataType().value()); - auto tensor_options = - at::TensorOptions().dtype(at_type).device(options.device); if (zero_init) { + auto tensor_options = + at::TensorOptions().dtype(at_type).device(options.device); c10::IntArrayRef isizes(sizes); return at::zeros(isizes, tensor_options); } else { c10::IntArrayRef isizes(sizes); // Non Variable type guard for empty_cuda call at::AutoNonVariableTypeMode non_variable_type_mode; - return at::native::empty_cuda(isizes, tensor_options); + return at::native::empty_cuda( + isizes, at_type, c10::nullopt, options.device, c10::nullopt); } } @@ -411,18 +412,20 @@ std::vector FusionExecutor::runFusion( // take the short-cut for launch if we see a recorded input set again; launch_params = executor_entry->launch_params; for (size_t i = 0; i < executor_entry->output_sizes.size(); i++) { - auto tensor_options = at::TensorOptions() - .dtype(executor_entry->output_types[i]) - .device(options_.device); alloced_outputs.push_back(at::native::empty_cuda( - executor_entry->output_sizes[i], tensor_options)); + executor_entry->output_sizes[i], + executor_entry->output_types[i], + c10::nullopt, + options_.device, + c10::nullopt)); } for (size_t i = 0; i < executor_entry->empty_buffer_sizes.size(); i++) { - auto tensor_options = at::TensorOptions() - .dtype(executor_entry->empty_buffer_types[i]) - .device(options_.device); global_buffers.empty_buffers.push_back(at::native::empty_cuda( - executor_entry->empty_buffer_sizes[i], tensor_options)); + executor_entry->empty_buffer_sizes[i], + executor_entry->empty_buffer_types[i], + c10::nullopt, + options_.device, + c10::nullopt)); } } for (size_t i = 0; i < executor_entry->zero_buffer_sizes.size(); i++) { From 8ff0b6fef822edae1b557649b09dccfdaa52a2e3 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Thu, 12 Nov 2020 17:08:56 -0800 Subject: [PATCH 83/93] [OpBenchMobile] Enable operator_benchmark to run the benchmark on mobile through AiBench (#47767) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47767 This diff implements the functionality of running benchmark on mobile on top of operator_benchmark framework. It does so through a few steps: 1. create a scripted module from existing benchmark case. 2. run mobile specific optimization pass on the scripted module 3. run the scripted module on AiBench by calling its Python API A small change in the way of writing a benchmark case is introduced so that both local and mobile run can share the same interface. The change is about having inputs as arguments of the `forward` function, so that mobile optimization pass can be run successfully (otherwise everything will be optimized away by constant propagation). Test Plan: ## local op_bench run buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --iterations 1 --warmup_iterations 1 buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --iterations 1 --warmup_iterations 1 --use_jit Exceptions: `py_module` op in `FakeQuantizePerTensorBaseOpBenchmark` and `FakeQuantizePerChannelBaseOpBenchmark` under JIT mode. These tests also failed in the base version ``` RuntimeError: Module 'FakeQuantizePerChannelOpBenchmark' has no attribute 'op_func' (This function exists as an attribute on the Python module, but we failed to compile it to a TorchScript function. The error stack is reproduced here: Python builtin is currently not supported in Torchscript: File "/data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/pt/quantization_test#link-tree/quantization_test.py", line 260 quant_min: int, quant_max: int ): return _LearnableFakeQuantizePerChannelOp.apply(input, scale, zero_point, axis, quant_min, quant_max, 1.0) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE : File "/data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/pt/quantization_test#link-tree/quantization_test.py", line 313 axis: int, quant_min: int, quant_max: int ): return self.op_func(input, scale, zero_point, axis, quant_min, quant_max) ~~~~~~~~~~~~ <--- HERE ``` `_consume_op` typing mismatch: chunk, split, qobserver, sort in qunary. These will be fixed in D24774105 ## OSS test python3 -m benchmark_all_test --iterations 1 --warmup_iterations 1 --use_jit python3 -m benchmark_all_test --iterations 1 --warmup_iterations 1 ## saved module graph ``` module __torch__.mobile_benchmark_utils.OpBenchmarkMobile { parameters { } attributes { training = True num_iters = 1 benchmark = <__torch__.pt.add_test.___torch_mangle_4.AddBenchmark object at 0x6070001b8b50> } methods { method forward { graph(%self : __torch__.mobile_benchmark_utils.OpBenchmarkMobile): %12 : None = prim::Constant() # /data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/fb/pt/mobile/benchmark_all_test_fbcode#link-tree/mobile_benchmark_utils.py:9:4 %4 : bool = prim::Constant[value=1]() # /data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/fb/pt/mobile/benchmark_all_test_fbcode#link-tree/mobile_benchmark_utils.py:10:8 %1 : int = prim::GetAttr[name="num_iters"](%self) = prim::Loop(%1, %4) # /data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/fb/pt/mobile/benchmark_all_test_fbcode#link-tree/mobile_benchmark_utils.py:10:8 block0(%i : int): %6 : __torch__.pt.add_test.___torch_mangle_4.AddBenchmark = prim::GetAttr[name="benchmark"](%self) %7 : __torch__.pt.add_test.___torch_mangle_4.AddBenchmark = prim::GetAttr[name="benchmark"](%self) %self.inputs_tuple : (Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu), Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu)) = prim::Constant[value=({0.48884}, {0.809042})]() %9 : Tensor, %10 : Tensor = prim::TupleUnpack(%self.inputs_tuple) %23 : int = prim::Constant[value=1]() %24 : Tensor = aten::add(%9, %10, %23) # /data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/fb/pt/mobile/benchmark_all_test_fbcode#link-tree/pt/add_test.py:39:15 -> (%4) return (%12) } } submodules { module __torch__.pt.add_test.___torch_mangle_4.AddBenchmark { parameters { } attributes { mobile_optimized = True } methods { method forward { graph(%self : __torch__.pt.add_test.___torch_mangle_4.AddBenchmark, %input_one.1 : Tensor, %input_two.1 : Tensor): %3 : int = prim::Constant[value=1]() %4 : Tensor = aten::add(%input_one.1, %input_two.1, %3) # /data/users/wangyang19/fbsource/fbcode/buck-out/dev/gen/caffe2/benchmarks/operator_benchmark/fb/pt/mobile/benchmark_all_test_fbcode#link-tree/pt/add_test.py:39:15 return (%4) } method get_inputs { graph(%self : __torch__.pt.add_test.___torch_mangle_4.AddBenchmark): %self.inputs_tuple : (Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu), Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu)) = prim::Constant[value=({0.48884}, {0.809042})]() return (%self.inputs_tuple) } } submodules { } } } } ``` Reviewed By: kimishpatel Differential Revision: D24322214 fbshipit-source-id: 335317eca4f40c4083883eb41dc47caf25cbdfd1 --- benchmarks/operator_benchmark/README.md | 75 ++++--- .../benchmark_all_other_test.py | 5 +- .../benchmark_all_quantized_test.py | 1 + .../operator_benchmark/benchmark_caffe2.py | 4 + .../operator_benchmark/benchmark_core.py | 1 + .../operator_benchmark/benchmark_pytorch.py | 61 +++--- .../operator_benchmark/benchmark_runner.py | 16 +- .../benchmark_test_generator.py | 4 + benchmarks/operator_benchmark/pt/add_test.py | 46 +++-- .../operator_benchmark/pt/as_strided_test.py | 17 +- .../operator_benchmark/pt/batchnorm_test.py | 16 +- .../operator_benchmark/pt/binary_test.py | 25 ++- benchmarks/operator_benchmark/pt/cat_test.py | 14 +- .../pt/channel_shuffle_test.py | 13 +- .../operator_benchmark/pt/chunk_test.py | 14 +- .../operator_benchmark/pt/clip_ranges_test.py | 11 +- benchmarks/operator_benchmark/pt/conv_test.py | 48 +++-- benchmarks/operator_benchmark/pt/diag_test.py | 16 +- .../pt/embeddingbag_test.py | 11 +- benchmarks/operator_benchmark/pt/fill_test.py | 8 +- .../operator_benchmark/pt/gather_test.py | 12 +- .../operator_benchmark/pt/groupnorm_test.py | 18 +- .../operator_benchmark/pt/hardsigmoid_test.py | 8 +- .../operator_benchmark/pt/hardswish_test.py | 8 +- .../pt/instancenorm_test.py | 16 +- .../operator_benchmark/pt/layernorm_test.py | 17 +- .../operator_benchmark/pt/linear_test.py | 8 +- .../operator_benchmark/pt/matmul_test.py | 16 +- .../operator_benchmark/pt/nan_to_num_test.py | 42 ++-- benchmarks/operator_benchmark/pt/pool_test.py | 36 ++-- .../operator_benchmark/pt/qactivation_test.py | 46 +++-- .../operator_benchmark/pt/qarithmetic_test.py | 75 ++++--- .../operator_benchmark/pt/qbatchnorm_test.py | 52 +++-- benchmarks/operator_benchmark/pt/qcat_test.py | 10 +- .../pt/qcomparators_test.py | 27 ++- .../operator_benchmark/pt/qconv_test.py | 24 ++- .../pt/qembedding_bag_lookups_test.py | 71 +++++-- .../pt/qembedding_pack_test.py | 18 +- .../pt/qembeddingbag_test.py | 8 +- .../operator_benchmark/pt/qgroupnorm_test.py | 29 +-- .../pt/qinstancenorm_test.py | 27 +-- .../pt/qinterpolate_test.py | 15 +- .../operator_benchmark/pt/qlayernorm_test.py | 22 +- .../operator_benchmark/pt/qlinear_test.py | 12 +- .../operator_benchmark/pt/qobserver_test.py | 13 +- .../operator_benchmark/pt/qpool_test.py | 8 +- benchmarks/operator_benchmark/pt/qrnn_test.py | 27 ++- .../pt/qtensor_method_test.py | 28 +-- .../pt/quantization_test.py | 195 ++++++++++++------ .../operator_benchmark/pt/qunary_test.py | 20 +- .../operator_benchmark/pt/remainder_test.py | 9 +- .../operator_benchmark/pt/softmax_test.py | 8 +- .../operator_benchmark/pt/split_test.py | 10 +- benchmarks/operator_benchmark/pt/sum_test.py | 9 +- .../operator_benchmark/pt/tensor_to_test.py | 16 +- .../operator_benchmark/pt/unary_test.py | 59 ++++-- torch/nn/quantized/functional.py | 4 +- 57 files changed, 898 insertions(+), 531 deletions(-) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 9cdd46a4ea21..2f170ab847dd 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -60,13 +60,15 @@ add_short_configs = op_bench.cross_product_configs( ) class AddBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K): - self.input_one = torch.rand(M, N, K) - self.input_two = torch.rand(M, N, K) + def init(self, M, N, K, device): + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("add") - def forward(self): - return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) op_bench.generate_pt_test(add_short_configs, AddBenchmark) ``` @@ -174,14 +176,15 @@ add_short_configs = op_bench.config_list( ) class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device): + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } + self.set_module_name("add") -    def init(self, M, N, K): -        self.input_one = torch.rand(M, N, K) -        self.input_two = torch.rand(M, N, K) -        self.set_module_name("add") - -    def forward(self): -        return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) @@ -218,26 +221,28 @@ Let's look at it in detail: #### Part 2. Create Tensors and Add Computation After inputs are provided, we now look at adding the computation of an operator. Adding a new operator requires implementing a new `TorchBenchmarkBase` subclass. Every new class is required to implement 2 methods: -* `init` is used to create tensors based on the inputs we provided before. In this example, the parameters to `init` are `M, N, and K` which have been specified in the input configuration. -* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Besides the object itself, it doesn't take any additional parameters.  +* `init` is used to create tensors based on the inputs we provided before. In this example, the parameters to `init` are `M, N, and K` which have been specified in the input configuration. `init` also packed all the needed inputs together into a dictionary `self.inputs` which will be provided to `forward` as arguments for running the benchmark. +* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Apart from `self`, the order of the arguments must match the entries specified in `self.inputs`.   The example below shows the code for `torch.add`:   ``` # Given one set of M, N, K, the init method creates input tensors based on # that. The forward method does torch.add calculation on those input tensors. -class AddBenchmark(op_bench.TorchBenchmarkBase): -    def init(self, M, N, K): +class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device): # this is the method where you need to create tensors # M, N, and K can be in different order, but they must match with # names in the configs. -        self.input_one = torch.rand(M, N, K) -        self.input_two = torch.rand(M, N, K) -        self.set_module_name("add") + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } + self.set_module_name("add") -    def forward(self): + def forward(self, input_one, input_two): # this is the method to have operator and do computation -        return torch.add(self.input_one, self.input_two) + return torch.add(input_one, input_two) ``` #### Part 3. Register Tests With the Benchmark Suite @@ -336,13 +341,14 @@ unary_ops_list = op_bench.op_list( ) class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, device, op_func): + self.inputs = { + "input": torch.rand(M, N, device=device) + } + self.op_func = op_func -    def init(self, M, N, op_func): -        self.input_one = torch.rand(M, N) -        self.op_func = op_func - -    def forward(self): -        return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) @@ -371,20 +377,21 @@ unary_ops_list = op_bench.op_list( In this example, both operators share the same input so we only need to implement one TorchBenchmakrBase subclass.  Every new subclass is required to implement 3 methods: * `init` is used to create tensors and set the operator name and function. In this example, the parameters to `init` are `M`, `N`, and `op_func` which have been specified in the configurations. -* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Besides the object itself, it doesn't take any additional parameters.  +* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Apart from `self`, the order of the arguments must match the entries specified in `self.inputs`. Here is the code for `abs` and `acos`: ``` class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): - -    def init(self, M, N, op_func): + def init(self, M, N, device, op_func): # The M and N match with the attr_names in the input configuration # The op_func matches with the attr_name in the ops configuration -        self.input_one = torch.rand(M, N) -        self.op_func = op_func + self.inputs = { + "input": torch.rand(M, N, device=device) + } + self.op_func = op_func -    def forward(self): -        return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) ``` #### Part 3. Register a List of Operators diff --git a/benchmarks/operator_benchmark/benchmark_all_other_test.py b/benchmarks/operator_benchmark/benchmark_all_other_test.py index 4ea7ab47a4c2..adaf8a09ee96 100644 --- a/benchmarks/operator_benchmark/benchmark_all_other_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_other_test.py @@ -2,9 +2,10 @@ from pt import ( # noqa add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa channel_shuffle_test, chunk_test, conv_test, diag_test, embeddingbag_test, # noqa - fill_test, gather_test, linear_test, matmul_test, pool_test, # noqa + fill_test, gather_test, linear_test, matmul_test, nan_to_num_test, pool_test, # noqa softmax_test, hardsigmoid_test, hardswish_test, layernorm_test, # noqa - groupnorm_test, instancenorm_test # noqa + groupnorm_test, instancenorm_test, remainder_test, softmax_test, # noqa + split_test, sum_test, tensor_to_test # noqa ) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py index 076a2685f61e..d0f5f9ff7896 100644 --- a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py @@ -18,6 +18,7 @@ quantization_test, qunary_test, qembedding_pack_test, + qembeddingbag_test, ) diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py index de56e00fa225..4fb7fffb5a5d 100644 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ b/benchmarks/operator_benchmark/benchmark_caffe2.py @@ -93,6 +93,10 @@ def test_name(self, name_type="long", **kargs): Caffe2BenchmarkBase.test_index += 1 return name + def extract_inputs_tuple(self): + # add a dummy function here to match the interface of TorchBenchmarkBase + pass + class Caffe2OperatorTestCase(object): """ This class includes all the information needed to benchmark an operator. diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index c519bc24ef97..10d08b100d66 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -118,6 +118,7 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct op._set_backward_test(run_backward) op.init(**init_dict) + op.extract_inputs_tuple() if not run_backward: for _, attr in vars(op).items(): diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index 1c5a905f2b75..2203a0af2ec3 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -10,7 +10,7 @@ microbenchmarks. """ -class TorchBenchmarkBase(object): +class TorchBenchmarkBase(torch.nn.Module): """ This is a base class used to create Pytorch operator benchmark. module_name is the name of the operator being benchmarked. test_name is the name (it's created by concatenating all the @@ -18,8 +18,8 @@ class TorchBenchmarkBase(object): """ def __init__(self): + super(TorchBenchmarkBase, self).__init__() self.user_given_name = None - self._jit_forward = None self._pass_count = 0 self._num_inputs_require_grads = 0 @@ -49,32 +49,26 @@ def auto_set(self): self._auto_set_counter += 1 return (self._pass_count == self._auto_set_counter) - def forward(self): - pass + def extract_inputs_tuple(self): + self.inputs_tuple = tuple(self.inputs.values()) - def _wrap_forward(self, foo): - """ The function passed to JIT trace must have at least one argument, - this function is to wrap the forward method to meet that requirement. - _consume op is used to avoid the dead-code-elimination optimization - in JIT. - """ - return torch.ops.operator_benchmark._consume(self.forward()) - - def _generate_jit_forward_graph(self): - """ generate a graph for the forward function via tracing - """ + @torch.jit.export + def get_inputs(self): + # Need to convert the inputs to tuple outside of JIT so that + # JIT can infer the size of the inputs. + return self.inputs_tuple - func = torch.jit.trace(self._wrap_forward, torch.rand(1)) - place_holder = torch.rand(1) # noqa + @torch.jit.export + def forward_impl(self): + # This is to supply the inputs to the forward function which + # will be called in both the eager and JIT mode of local runs + return self.forward(*self.get_inputs()) - @torch.jit.script - def _jit_forward_graph(iters, place_holder): - # type: (int, Tensor) - result = torch.jit.annotate(torch.Tensor, place_holder) - for _ in range(iters): - result = func(place_holder) - return result - return _jit_forward_graph + @torch.jit.export + def forward_consume(self, iters: int): + # _consume is used to avoid the dead-code-elimination optimization + for _ in range(iters): + torch.ops.operator_benchmark._consume(self.forward_impl()) def module_name(self): """ this is used to label the operator being benchmarked @@ -121,13 +115,20 @@ def __init__(self, op_bench, test_config): self.place_holder_tensor = torch.ones(1) self.framework = "PyTorch" self.time_series = [] + self._jit_forward_graph = None + + def _generate_jit_forward_graph(self): + """ generate a graph for the forward function via scripting + """ + scripted_op_bench = torch.jit.script(self.op_bench) + return scripted_op_bench.forward_consume def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False): """ Run the forward path of an op with JIT mode """ - if self.op_bench._jit_forward is None: - self.op_bench._jit_forward = self.op_bench._generate_jit_forward_graph() - self.op_bench._jit_forward(num_runs, self.place_holder_tensor) + if self._jit_forward_graph is None: + self._jit_forward_graph = self._generate_jit_forward_graph() + self._jit_forward_graph(num_runs) def _print_per_iter(self): # print last 50 values @@ -148,14 +149,14 @@ def run_forward(self, num_runs, print_per_iter, cuda_sync): if print_per_iter: for _ in range(num_runs): start_time = time.time() - self.output = self.op_bench.forward() + self.output = self.op_bench.forward_impl() if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) end_time = time.time() self.time_series.append((end_time - start_time) * 1e3) else: for _ in range(num_runs): - self.output = self.op_bench.forward() + self.output = self.op_bench.forward_impl() if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 1a3ec19d7ece..b9347364428e 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -10,14 +10,12 @@ This is the main function for running performance microbenchmark tests. It also registers existing benchmark tests via Python module imports. """ +parser = argparse.ArgumentParser( + description="Run microbenchmarks.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) - -def main(): - parser = argparse.ArgumentParser( - description="Run microbenchmarks.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - +def parse_args(): parser.add_argument( '--tag_filter', help='tag_filter can be used to run the shapes which matches the tag. (all is used to run all the shapes)', @@ -145,6 +143,10 @@ def main(): if args.mkl_num_threads: benchmark_utils.set_mkl_threads(args.mkl_num_threads) + return args + +def main(): + args = parse_args() benchmark_core.BenchmarkRunner(args).run() diff --git a/benchmarks/operator_benchmark/benchmark_test_generator.py b/benchmarks/operator_benchmark/benchmark_test_generator.py index 6dd8150dfccd..ec60c33c205a 100644 --- a/benchmarks/operator_benchmark/benchmark_test_generator.py +++ b/benchmarks/operator_benchmark/benchmark_test_generator.py @@ -37,3 +37,7 @@ def forward(self): """ for op in ops_list: _register_test(configs, pt_bench_op, create_pytorch_op_test_case, False, op) + +def generate_pt_gradient_tests_from_op_list(ops_list, configs, pt_bench_op): + for op in ops_list: + _register_test(configs, pt_bench_op, create_pytorch_op_test_case, True, op) diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index 911c44b5436a..de0277fc3de7 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -29,12 +29,14 @@ class AddBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - self.input_two = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("add") - def forward(self): - return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) # The generated test names based on add_short_configs will be in the following pattern: # add_M8_N16_K32_devicecpu @@ -53,13 +55,15 @@ def forward(self): class AddmmBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, K, device=device, requires_grad=self.auto_set()) - self.mat1 = torch.rand(M, N, device=device, requires_grad=self.auto_set()) - self.mat2 = torch.rand(N, K, device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()), + "mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()), + "mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("addmm") - def forward(self): - return torch.addmm(self.input_one, self.mat1, self.mat2) + def forward(self, input_one, mat1, mat2): + return torch.addmm(input_one, mat1, mat2) op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark) op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark) @@ -70,13 +74,15 @@ def forward(self): class AddrBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device, dtype): - self.input_one = torch.rand((M, N), device=device, requires_grad=self.auto_set(), dtype=dtype) - self.vec1 = torch.rand((M,), device=device, requires_grad=self.auto_set(), dtype=dtype) - self.vec2 = torch.rand((N,), device=device, requires_grad=self.auto_set(), dtype=dtype) + self.inputs = { + "input_one": torch.rand((M, N), device=device, requires_grad=self.auto_set(), dtype=dtype), + "vec1": torch.rand((M,), device=device, requires_grad=self.auto_set(), dtype=dtype), + "vec2": torch.rand((N,), device=device, requires_grad=self.auto_set(), dtype=dtype) + } self.set_module_name("addr") - def forward(self): - return torch.addr(self.input_one, self.vec1, self.vec2) + def forward(self, input_one, vec1, vec2): + return torch.addr(input_one, vec1, vec2) addr_configs = op_bench.cross_product_configs( M=[8, 256], @@ -95,13 +101,15 @@ def forward(self): class AddbmmBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device): - self.input_one = torch.rand((M, N), device=device, requires_grad=self.auto_set()) - self.batch1 = torch.rand((B, M, K), device=device, requires_grad=self.auto_set()) - self.batch2 = torch.rand((B, K, N,), device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand((M, N), device=device, requires_grad=self.auto_set()), + "batch1": torch.rand((B, M, K), device=device, requires_grad=self.auto_set()), + "batch2": torch.rand((B, K, N,), device=device, requires_grad=self.auto_set()) + } self.set_module_name("addbmm") - def forward(self): - return torch.addbmm(self.input_one, self.batch1, self.batch2) + def forward(self, input_one, batch1, batch2): + return torch.addbmm(input_one, batch1, batch2) addbmm_configs = op_bench.cross_product_configs( B=[2, 100], diff --git a/benchmarks/operator_benchmark/pt/as_strided_test.py b/benchmarks/operator_benchmark/pt/as_strided_test.py index a43702c15e22..77eff29811be 100644 --- a/benchmarks/operator_benchmark/pt/as_strided_test.py +++ b/benchmarks/operator_benchmark/pt/as_strided_test.py @@ -1,5 +1,6 @@ import operator_benchmark as op_bench import torch +from typing import List """Microbenchmarks for as_strided operator""" @@ -32,15 +33,19 @@ class As_stridedBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, size, stride, storage_offset, device): - self.input_one = torch.rand(M, N, device=device) - self.size = size - self.stride = stride - self.storage_offset = storage_offset + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "size": size, + "stride": stride, + "storage_offset": storage_offset + } self.set_module_name('as_strided') - def forward(self): + def forward( + self, input_one, size: List[int], stride: List[int], storage_offset: int + ): return torch.as_strided( - self.input_one, self.size, self.stride, self.storage_offset) + input_one, size, stride, storage_offset) op_bench.generate_pt_test(as_strided_configs_short + as_strided_configs_long, diff --git a/benchmarks/operator_benchmark/pt/batchnorm_test.py b/benchmarks/operator_benchmark/pt/batchnorm_test.py index 7257be36b9f1..816bdcc55342 100644 --- a/benchmarks/operator_benchmark/pt/batchnorm_test.py +++ b/benchmarks/operator_benchmark/pt/batchnorm_test.py @@ -28,15 +28,17 @@ class BatchNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - self.mean = torch.rand(N, device=device) - self.var = torch.rand(N, device=device) - self.weight = torch.rand(N, device=device) - self.bias = torch.rand(N, device=device) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "mean": torch.rand(N, device=device), + "var": torch.rand(N, device=device), + "weight": torch.rand(N, device=device), + "bias": torch.rand(N, device=device) + } self.set_module_name("batchnorm") - def forward(self): - return F.batch_norm(self.input_one, self.mean, self.var, self.weight, self.bias) + def forward(self, input_one, mean, var, weight, bias): + return F.batch_norm(input_one, mean, var, weight, bias) op_bench.generate_pt_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/binary_test.py b/benchmarks/operator_benchmark/pt/binary_test.py index bd177775764e..9650392deae7 100644 --- a/benchmarks/operator_benchmark/pt/binary_test.py +++ b/benchmarks/operator_benchmark/pt/binary_test.py @@ -29,12 +29,14 @@ class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase): def init(self, in_one, in_two, dtype, device, op_func): - self.in_one = torch.randn(in_one, device=device).to(dtype=dtype) - self.in_two = torch.randn(in_two, device=device).to(dtype=dtype) + self.inputs = { + "in_one": torch.randn(in_one, device=device).to(dtype=dtype), + "in_two": torch.randn(in_two, device=device).to(dtype=dtype) + } self.op_func = op_func - def forward(self): - return self.op_func(self.in_one, self.in_two) + def forward(self, in_one, in_two): + return self.op_func(in_one, in_two) op_bench.generate_pt_tests_from_op_list(binary_ops_bcast_list, @@ -42,12 +44,15 @@ def forward(self): BinaryOpBcastBenchmark) +def copy(in1, in2): + return in1.copy_(in2) + # Benchmark ops performance without broadcast binary_ops_list = op_bench.op_list( attr_names=['op_name', 'op_func'], attrs=[ ['add', torch.add], - ['copy_', lambda in1, in2: in1.copy_(in2)], + ['copy_', copy], ], ) @@ -79,12 +84,14 @@ def forward(self): class BinaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device, dtype_one, dtype_two, op_func): - self.input_one = torch.randn(M, N, K, device=device).to(dtype=dtype_one) - self.input_two = torch.randn(M, N, K, device=device).to(dtype=dtype_two) + self.inputs = { + "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one), + "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two) + } self.op_func = op_func - def forward(self): - return self.op_func(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return self.op_func(input_one, input_two) op_bench.generate_pt_tests_from_op_list(binary_ops_list, diff --git a/benchmarks/operator_benchmark/pt/cat_test.py b/benchmarks/operator_benchmark/pt/cat_test.py index 97df91087366..c1022f296a2f 100644 --- a/benchmarks/operator_benchmark/pt/cat_test.py +++ b/benchmarks/operator_benchmark/pt/cat_test.py @@ -1,6 +1,7 @@ import operator_benchmark as op_bench import torch import random +from typing import List """Microbenchmarks for Cat operator""" @@ -78,16 +79,19 @@ class CatBenchmark(op_bench.TorchBenchmarkBase): def init(self, sizes, N, dim, device): random.seed(42) - self.inputs = [] + inputs = [] for i in range(N): current_sizes = [old_size() if callable(old_size) else old_size for old_size in sizes] - self.inputs.append(torch.rand(current_sizes, device=device)) - self.dim = dim + inputs.append(torch.rand(current_sizes, device=device)) + self.inputs = { + "inputs": inputs, + "dim": dim + } self.set_module_name('cat') - def forward(self): - return torch.cat(self.inputs, dim=self.dim) + def forward(self, inputs: List[torch.Tensor], dim: int): + return torch.cat(inputs, dim=dim) op_bench.generate_pt_test(cat_configs_short + diff --git a/benchmarks/operator_benchmark/pt/channel_shuffle_test.py b/benchmarks/operator_benchmark/pt/channel_shuffle_test.py index 258bb6d69c04..87163f004b2d 100644 --- a/benchmarks/operator_benchmark/pt/channel_shuffle_test.py +++ b/benchmarks/operator_benchmark/pt/channel_shuffle_test.py @@ -36,16 +36,19 @@ class ChannelSHuffleBenchmark(op_bench.TorchBenchmarkBase): def init(self, batch_size, channels_per_group, height, width, groups, channel_last): - self.groups = groups channels = channels_per_group * groups data_shape = (batch_size, channels, height, width) - self.input_data = torch.rand(data_shape) + input_data = torch.rand(data_shape) if channel_last: - self.input_data = self.input_data.contiguous(memory_format=torch.channels_last) + input_data = input_data.contiguous(memory_format=torch.channels_last) + self.inputs = { + "input_data": input_data, + "groups": groups + } self.set_module_name('channel_shuffle') - def forward(self): - return torch.channel_shuffle(self.input_data, self.groups) + def forward(self, input_data, groups: int): + return torch.channel_shuffle(input_data, groups) op_bench.generate_pt_test(channel_shuffle_short_configs + channel_shuffle_long_configs, diff --git a/benchmarks/operator_benchmark/pt/chunk_test.py b/benchmarks/operator_benchmark/pt/chunk_test.py index 885301dfdcb0..6c1148dbcdaa 100644 --- a/benchmarks/operator_benchmark/pt/chunk_test.py +++ b/benchmarks/operator_benchmark/pt/chunk_test.py @@ -30,12 +30,14 @@ class ChunkBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, chunks, device): - self.input_one = torch.rand(M, N, device=device) - self.chunks = chunks - self.set_module_name('chunk') - - def forward(self): - return torch.chunk(self.input_one, self.chunks) + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "chunks": chunks + } + self.set_module_name("chunk") + + def forward(self, input_one, chunks: int): + return torch.chunk(input_one, chunks) op_bench.generate_pt_test(chunk_short_configs + chunks_long_configs, diff --git a/benchmarks/operator_benchmark/pt/clip_ranges_test.py b/benchmarks/operator_benchmark/pt/clip_ranges_test.py index d2c0d575647b..3b6b95d93786 100644 --- a/benchmarks/operator_benchmark/pt/clip_ranges_test.py +++ b/benchmarks/operator_benchmark/pt/clip_ranges_test.py @@ -35,13 +35,14 @@ class ClipRangesBenchmark(op_bench.TorchBenchmarkBase): def init(self, LENGTH, M, N, MAX_LENGTH, device, dtype): - self.input = torch.rand(LENGTH, M, N, device=device).type(dtype) - self.max_length = MAX_LENGTH + self.inputs = { + "input": torch.rand(LENGTH, M, N, device=device).type(dtype), + "max_length": MAX_LENGTH + } self.set_module_name("clip_ranges") - def forward(self): - output = torch.ops.fb.clip_ranges(self.input, self.max_length) - return output + def forward(self, input, max_length: int): + return torch.ops.fb.clip_ranges(input, max_length) op_bench.generate_pt_test( diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index 6126cc0f6d3c..28b511815716 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -11,22 +11,26 @@ class Conv1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): - self.input = torch.rand(N, IC, L, device=device) + self.inputs = { + "input": torch.rand(N, IC, L, device=device, requires_grad=self.auto_set()) + } self.conv1d = nn.Conv1d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('Conv1d') - def forward(self): - return self.conv1d(self.input) + def forward(self, input): + return self.conv1d(input) class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): - self.input = torch.rand(N, IC, L, device=device) + self.inputs = { + "input": torch.rand(N, IC, L, device=device) + } self.convtranspose1d = nn.ConvTranspose1d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('ConvTranspose1d') - def forward(self): - return self.convtranspose1d(self.input) + def forward(self, input): + return self.convtranspose1d(input) op_bench.generate_pt_test(configs.conv_1d_configs_short + configs.conv_1d_configs_long, @@ -42,24 +46,28 @@ def forward(self): class Conv2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): - self.input = torch.rand(N, IC, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, H, W, device=device) + } self.conv2d = nn.Conv2d( IC, OC, kernel, stride=stride, groups=G, padding=pad).to(device=device) self.set_module_name('Conv2d') - def forward(self): - return self.conv2d(self.input) + def forward(self, input): + return self.conv2d(input) class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): - self.input = torch.rand(N, IC, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, H, W, device=device) + } self.convtranspose2d = nn.ConvTranspose2d( IC, OC, kernel, stride=stride, groups=G, padding=pad).to(device=device) self.set_module_name('ConvTranspose2d') - def forward(self): - return self.convtranspose2d(self.input) + def forward(self, input): + return self.convtranspose2d(input) op_bench.generate_pt_test(configs.conv_2d_configs_short + configs.conv_2d_configs_long, @@ -74,22 +82,26 @@ def forward(self): class Conv3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): - self.input = torch.rand(N, IC, D, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, D, H, W, device=device) + } self.conv3d = nn.Conv3d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('Conv3d') - def forward(self): - return self.conv3d(self.input) + def forward(self, input): + return self.conv3d(input) class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): - self.input = torch.rand(N, IC, D, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, D, H, W, device=device) + } self.convtranspose3d = nn.ConvTranspose3d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('ConvTranspose3d') - def forward(self): - return self.convtranspose3d(self.input) + def forward(self, input): + return self.convtranspose3d(input) op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/diag_test.py b/benchmarks/operator_benchmark/pt/diag_test.py index 3fa105442895..79ad86d29510 100644 --- a/benchmarks/operator_benchmark/pt/diag_test.py +++ b/benchmarks/operator_benchmark/pt/diag_test.py @@ -22,13 +22,19 @@ class DiagBenchmark(op_bench.TorchBenchmarkBase): def init(self, dim, M, N, diagonal, out, device): - self.input = torch.rand(M, N, device=device) if dim == 2 else torch.rand(M, device=device) - self.diagonal = diagonal - self.out = torch.tensor((),) if out else None + self.inputs = { + "input": torch.rand(M, N, device=device) if dim == 2 else torch.rand(M, device=device), + "diagonal": diagonal, + "out": out, + "out_tensor": torch.tensor((),) + } self.set_module_name('diag') - def forward(self): - return torch.diag(self.input, diagonal=self.diagonal, out=self.out) + def forward(self, input, diagonal: int, out: bool, out_tensor): + if out: + return torch.diag(input, diagonal=diagonal, out=out_tensor) + else: + return torch.diag(input, diagonal=diagonal) op_bench.generate_pt_test(diag_configs_short, DiagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/embeddingbag_test.py b/benchmarks/operator_benchmark/pt/embeddingbag_test.py index c93fbf09206c..a8c100a79721 100644 --- a/benchmarks/operator_benchmark/pt/embeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/embeddingbag_test.py @@ -14,13 +14,16 @@ def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_las include_last_offset=include_last_offset, sparse=sparse).to(device=device) numpy.random.seed((1 << 32) - 1) - self.input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() offsets = torch.LongTensor([offset], device=device) - self.offset = torch.cat((offsets, torch.tensor([self.input.size(0)], dtype=torch.long)), 0) + input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() + self.inputs = { + "input": input, + "offset": torch.cat((offsets, torch.tensor([input.size(0)], dtype=torch.long)), 0) + } self.set_module_name('embeddingbag') - def forward(self): - return self.embedding(self.input, self.offset) + def forward(self, input, offset): + return self.embedding(input, offset) op_bench.generate_pt_test(configs.embeddingbag_short_configs, EmbeddingBagBenchmark) op_bench.generate_pt_gradient_test(configs.embeddingbag_short_configs, EmbeddingBagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/fill_test.py b/benchmarks/operator_benchmark/pt/fill_test.py index 5a162db9f5f5..97f59394a66a 100644 --- a/benchmarks/operator_benchmark/pt/fill_test.py +++ b/benchmarks/operator_benchmark/pt/fill_test.py @@ -28,11 +28,13 @@ class Fill_Benchmark(op_bench.TorchBenchmarkBase): def init(self, N, device, dtype): - self.input_one = torch.zeros(N, device=device).type(dtype) + self.inputs = { + "input_one": torch.zeros(N, device=device).type(dtype) + } self.set_module_name("fill_") - def forward(self): - return self.input_one.fill_(10) + def forward(self, input_one): + return input_one.fill_(10) op_bench.generate_pt_test(fill_short_configs + fill_long_configs, diff --git a/benchmarks/operator_benchmark/pt/gather_test.py b/benchmarks/operator_benchmark/pt/gather_test.py index 509c1b937c3a..6538cb3a8b90 100644 --- a/benchmarks/operator_benchmark/pt/gather_test.py +++ b/benchmarks/operator_benchmark/pt/gather_test.py @@ -30,15 +30,17 @@ class GatherBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, dim, device): - self.input_one = torch.rand(M, N, device=device) - self.dim = dim min_val = M if dim == 0 else N numpy.random.seed((1 << 32) - 1) - self.index = torch.tensor(numpy.random.randint(0, min_val, (M, N)), device=device) + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "dim": dim, + "index": torch.tensor(numpy.random.randint(0, min_val, (M, N)), device=device) + } self.set_module_name("gather") - def forward(self): - return torch.gather(self.input_one, self.dim, self.index) + def forward(self, input_one, dim: int, index): + return torch.gather(input_one, dim, index) op_bench.generate_pt_test(gather_configs_short + gather_configs_long, diff --git a/benchmarks/operator_benchmark/pt/groupnorm_test.py b/benchmarks/operator_benchmark/pt/groupnorm_test.py index eb941b863dc7..f360ae26b207 100644 --- a/benchmarks/operator_benchmark/pt/groupnorm_test.py +++ b/benchmarks/operator_benchmark/pt/groupnorm_test.py @@ -18,16 +18,18 @@ class GroupNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims, num_groups): - self.X = (torch.rand(*dims) - 0.5) * 256 - self.num_groups = num_groups num_channels = dims[1] - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + self.inputs = { + "input": (torch.rand(*dims) - 0.5) * 256, + "num_groups": num_groups, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, num_groups: int, weight, bias, eps: float): return F.group_norm( - self.X, self.num_groups, weight=self.weight, bias=self.bias, eps=self.eps) + input, num_groups, weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(groupnorm_configs_short, GroupNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/hardsigmoid_test.py b/benchmarks/operator_benchmark/pt/hardsigmoid_test.py index c3011d0a1fe4..f1161e485e72 100644 --- a/benchmarks/operator_benchmark/pt/hardsigmoid_test.py +++ b/benchmarks/operator_benchmark/pt/hardsigmoid_test.py @@ -45,11 +45,13 @@ class HardsigmoidBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input_one": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input_one): + return self.op_func(input_one) op_bench.generate_pt_tests_from_op_list(hardsigmoid_ops_list, diff --git a/benchmarks/operator_benchmark/pt/hardswish_test.py b/benchmarks/operator_benchmark/pt/hardswish_test.py index 3879679bd33b..0f1f94c0ddba 100644 --- a/benchmarks/operator_benchmark/pt/hardswish_test.py +++ b/benchmarks/operator_benchmark/pt/hardswish_test.py @@ -45,11 +45,13 @@ class HardswishBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input_one": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input_one): + return self.op_func(input_one) op_bench.generate_pt_tests_from_op_list(hardswish_ops_list, diff --git a/benchmarks/operator_benchmark/pt/instancenorm_test.py b/benchmarks/operator_benchmark/pt/instancenorm_test.py index 4eac02bc8bd8..b152a9c75303 100644 --- a/benchmarks/operator_benchmark/pt/instancenorm_test.py +++ b/benchmarks/operator_benchmark/pt/instancenorm_test.py @@ -17,15 +17,17 @@ class InstanceNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims): - self.X = (torch.rand(*dims) - 0.5) * 256 num_channels = dims[1] - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + self.inputs = { + "input": (torch.rand(*dims) - 0.5) * 256, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, weight, bias, eps: float): return F.instance_norm( - self.X, weight=self.weight, bias=self.bias, eps=self.eps) + input, weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(instancenorm_configs_short, InstanceNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/layernorm_test.py b/benchmarks/operator_benchmark/pt/layernorm_test.py index f0aa81a8291c..b18abf26eaf8 100644 --- a/benchmarks/operator_benchmark/pt/layernorm_test.py +++ b/benchmarks/operator_benchmark/pt/layernorm_test.py @@ -19,14 +19,17 @@ class LayerNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims): - self.X = (torch.rand(*dims) - 0.5) * 256 - self.weight = torch.rand(*self.X.size()[1:], dtype=torch.float) - self.bias = torch.rand(*self.X.size()[1:], dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + input = (torch.rand(*dims) - 0.5) * 256 + self.inputs = { + "input": input, + "weight": torch.rand(*input.size()[1:], dtype=torch.float), + "bias": torch.rand(*input.size()[1:], dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, weight, bias, eps: float): return F.layer_norm( - self.X, self.X.size()[1:], weight=self.weight, bias=self.bias, eps=self.eps) + input, input.size()[1:], weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(layernorm_configs_short, LayerNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/linear_test.py b/benchmarks/operator_benchmark/pt/linear_test.py index f6728da67a8e..84263ed6f2d4 100644 --- a/benchmarks/operator_benchmark/pt/linear_test.py +++ b/benchmarks/operator_benchmark/pt/linear_test.py @@ -11,12 +11,14 @@ class LinearBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, IN, OUT, device): - self.input_one = torch.rand(N, IN, device=device) + self.inputs = { + "input_one": torch.rand(N, IN, device=device) + } self.linear = nn.Linear(IN, OUT).to(device=device) self.set_module_name("linear") - def forward(self): - return self.linear(self.input_one) + def forward(self, input_one): + return self.linear(input_one) op_bench.generate_pt_test(configs.linear_configs_short + configs.linear_configs_long, diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index 0c60524b911a..e5d7d27589d4 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -31,14 +31,18 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, trans_a, trans_b, device): - self.input_one = torch.rand(M, N, device=device) if trans_a \ - else torch.rand(N, M, device=device).t() - self.input_two = torch.rand(N, K, device=device) if trans_b \ - else torch.rand(K, N, device=device).t() + self.inputs = { + "input_one": torch.rand(M, N, device=device) + if trans_a + else torch.rand(N, M, device=device).t(), + "input_two": torch.rand(N, K, device=device) + if trans_b + else torch.rand(K, N, device=device).t(), + } self.set_module_name("matmul") - def forward(self): - return torch.matmul(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.matmul(input_one, input_two) op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) diff --git a/benchmarks/operator_benchmark/pt/nan_to_num_test.py b/benchmarks/operator_benchmark/pt/nan_to_num_test.py index 72f5daf33afa..0e8d82ad781f 100644 --- a/benchmarks/operator_benchmark/pt/nan_to_num_test.py +++ b/benchmarks/operator_benchmark/pt/nan_to_num_test.py @@ -6,11 +6,19 @@ """Microbenchmarks for torch.nan_to_num / nan_to_num_ operators""" # Configs for PT torch.nan_to_num / nan_to_num_ operators + +nan_to_num_ops_list = op_bench.op_list( + attr_names=['op_name', 'op_func'], + attrs=[ + ['nan_to_num', torch.nan_to_num], + ['nan_to_num_', torch.nan_to_num_], + ], +) + nan_to_num_long_configs = op_bench.cross_product_configs( M=[32, 64, 128], N=range(32, 128, 32), dtype=[torch.float, torch.double], - op=["nan_to_num", "nan_to_num_"], replace_inf=[True, False], tags=["long"], ) @@ -20,36 +28,32 @@ M=[16, 64], N=[64, 64], dtype=[torch.float, torch.double], - op=["nan_to_num", "nan_to_num_"], replace_inf=[True, False], tags=["short"], ) class ReplaceNaNBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, dtype, op, replace_inf): - self.input = torch.randn(M, N, dtype=dtype) - self.input[0][0] = float("nan") - self.op = op - self.replace_inf = replace_inf + def init(self, M, N, dtype, replace_inf, op_func): + input = torch.randn(M, N, dtype=dtype) + input[0][0] = float("nan") + self.inputs = { + "input": input, + "replace_inf": replace_inf + } + self.op_func = op_func self.set_module_name("nan_to_num") - def forward(self): + def forward(self, input, replace_inf: bool): # compare inplace - if self.op == "nan_to_num": - if self.replace_inf: - output = torch.nan_to_num(self.input, nan=1.0) - else: - output = torch.nan_to_num(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf) + if replace_inf: + return self.op_func(input, nan=1.0) else: - if self.replace_inf: - output = torch.nan_to_num_(self.input, nan=1.0) - else: - output = torch.nan_to_num_(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf) - return output + return self.op_func(input, nan=1.0, posinf=math.inf, neginf=-math.inf) -op_bench.generate_pt_test( +op_bench.generate_pt_tests_from_op_list( + nan_to_num_ops_list, nan_to_num_long_configs + nan_to_num_short_configs, ReplaceNaNBenchmark, ) diff --git a/benchmarks/operator_benchmark/pt/pool_test.py b/benchmarks/operator_benchmark/pt/pool_test.py index 88a75522566d..f465c41a0967 100644 --- a/benchmarks/operator_benchmark/pt/pool_test.py +++ b/benchmarks/operator_benchmark/pt/pool_test.py @@ -41,13 +41,13 @@ class Pool1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, L, device, op_func): - self.input = torch.rand(N, C, L, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, L, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_1d_ops_list, @@ -98,14 +98,14 @@ def forward(self): class Pool2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, H, W, device, op_func): - self.input = torch.rand(N, C, H, W, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, H, W, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_2d_ops_list, @@ -158,13 +158,13 @@ def forward(self): class Pool3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, D, H, W, device, op_func): - self.input = torch.rand(N, C, D, H, W, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, D, H, W, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_3d_ops_list, diff --git a/benchmarks/operator_benchmark/pt/qactivation_test.py b/benchmarks/operator_benchmark/pt/qactivation_test.py index 7ef51958dc60..d8107c0044d7 100644 --- a/benchmarks/operator_benchmark/pt/qactivation_test.py +++ b/benchmarks/operator_benchmark/pt/qactivation_test.py @@ -45,9 +45,6 @@ ('relu', nnq.functional.relu), ('relu6', torch.ops.quantized.relu6), ('functional.hardtanh', nnq.functional.hardtanh), - ('functional.hardswish', nnq.functional.hardswish), - ('functional.elu', nnq.functional.elu), - ('functional.celu', nnq.functional.celu), ('functional.hardsigmoid', nnq.functional.hardsigmoid), ('functional.leaky_relu', nnq.functional.leaky_relu), ('functional.sigmoid', torch.nn.functional.sigmoid), @@ -66,28 +63,49 @@ def _setup(self, dims, contig, dtype): self.zero_point = 0 # Quantize the tensor - self.q_input = torch.quantize_per_tensor(f_input, scale=self.scale, - zero_point=self.zero_point, - dtype=dtype) + q_input = torch.quantize_per_tensor(f_input, scale=self.scale, + zero_point=self.zero_point, + dtype=dtype) if not contig: # Make non-contiguous - new_shape = list(range(self.q_input.ndim))[::-1] - self.q_input = self.q_input.permute(new_shape) + new_shape = list(range(q_input.ndim))[::-1] + q_input = q_input.permute(new_shape) + + self.inputs = { + "q_input": q_input + } def init(self, dims, contig, inplace, dtype, op_func): self._setup(dims, contig, dtype) self.qop = op_func - def forward(self): - if self.qop in (nnq.functional.hardswish, nnq.functional.elu, - nnq.functional.celu): - return self.qop(self.q_input, scale=self.scale, zero_point=self.zero_point) - return self.qop(self.q_input) + +class QActivationBenchmark(QActivationBenchmarkBase): + def forward(self, q_input): + return self.qop(q_input) op_bench.generate_pt_tests_from_op_list(qactivation_ops, qactivation_short_configs + qactivation_long_configs, - QActivationBenchmarkBase) + QActivationBenchmark) + + +qactivation_scale_zero_point_ops = op_bench.op_list( + attrs=( + ('functional.hardswish', nnq.functional.hardswish), + ('functional.elu', nnq.functional.elu), + ('functional.celu', nnq.functional.celu), + ), + attr_names=('op_name', 'op_func'), +) + +class QActivationScaleZeroPointBenchmark(QActivationBenchmarkBase): + def forward(self, q_input): + return self.qop(q_input, scale=self.scale, zero_point=self.zero_point) + +op_bench.generate_pt_tests_from_op_list(qactivation_scale_zero_point_ops, + qactivation_short_configs + qactivation_long_configs, + QActivationScaleZeroPointBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/qarithmetic_test.py b/benchmarks/operator_benchmark/pt/qarithmetic_test.py index 87c75845900b..01be129fe597 100644 --- a/benchmarks/operator_benchmark/pt/qarithmetic_test.py +++ b/benchmarks/operator_benchmark/pt/qarithmetic_test.py @@ -1,5 +1,5 @@ import torch - +from torch._ops import ops import operator_benchmark as op_bench qarithmetic_binary_configs = op_bench.cross_product_configs( @@ -10,58 +10,77 @@ tags=('short',) ) + qarithmetic_binary_ops = op_bench.op_list( attrs=( - ('add', 'add'), - ('add_scalar', 'add_scalar'), - ('add_relu', 'add_relu'), - ('mul', 'mul'), - ('mul_scalar', 'mul_scalar'), + ('add', ops.quantized.add), + ('add_relu', ops.quantized.add_relu), + ('mul', ops.quantized.mul), ), attr_names=('op_name', 'op_func'), ) +qarithmetic_binary_scalar_ops = op_bench.op_list( + attrs=( + ('add_scalar', ops.quantized.add_scalar), + ('mul_scalar', ops.quantized.mul_scalar), + ), + attr_names=('op_name', 'op_func'), +) -r"""Base class to use QFunctional. - -Children will need to set `self.qop` to the qfunctional op under test. -I.e. `self.qop = 'add'` -""" class _QFunctionalBinaryArithmeticBenchmarkBase(op_bench.TorchBenchmarkBase): def setup(self, N, dtype, contig): self.qfunctional = torch.nn.quantized.QFunctional() # TODO: Consider more diverse shapes f_input = (torch.rand(N, N) - 0.5) * 256 - scale = 1.0 - zero_point = 0 - - self.q_input_a = torch.quantize_per_tensor(f_input, scale=scale, - zero_point=zero_point, + self.scale = 1.0 + self.zero_point = 0 + self.q_input_a = torch.quantize_per_tensor(f_input, scale=self.scale, + zero_point=self.zero_point, dtype=dtype) if not contig: permute_dims = list(range(f_input.ndim))[::-1] self.q_input_a = self.q_input_a.permute(permute_dims) - def forward(self): - return getattr(self.qfunctional, self.qop)(self.q_input_a, - self.q_input_b) - -class QFunctionalAddBenchmarkBase(_QFunctionalBinaryArithmeticBenchmarkBase): +class QFunctionalBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase): def init(self, N, dtype, contig, op_func): - super(QFunctionalAddBenchmarkBase, self).setup(N, dtype, contig) - self.qop = op_func - if self.qop.endswith('_scalar'): - self.q_input_b = 42 - else: - self.q_input_b = self.q_input_a + super(QFunctionalBenchmark, self).setup(N, dtype, contig) + self.inputs = { + "q_input_a": self.q_input_a, + "q_input_b": self.q_input_a, + "scale": self.scale, + "zero_point": self.zero_point + } + self.op_func = op_func + + def forward(self, q_input_a, q_input_b, scale: float, zero_point: int): + return self.op_func(q_input_a, q_input_b, scale=scale, zero_point=zero_point) op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_ops, qarithmetic_binary_configs, - QFunctionalAddBenchmarkBase) + QFunctionalBenchmark) + + +class QFunctionalScalarBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase): + def init(self, N, dtype, contig, op_func): + super(QFunctionalScalarBenchmark, self).setup(N, dtype, contig) + self.inputs = { + "q_input": self.q_input_a, + "scalar_input": 42 + } + self.op_func = op_func + + def forward(self, q_input, scalar_input: int): + return self.op_func(q_input, scalar_input) + + +op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_scalar_ops, + qarithmetic_binary_configs, + QFunctionalScalarBenchmark) if __name__ == '__main__': diff --git a/benchmarks/operator_benchmark/pt/qbatchnorm_test.py b/benchmarks/operator_benchmark/pt/qbatchnorm_test.py index f729f79dcce7..b7d591096a8d 100644 --- a/benchmarks/operator_benchmark/pt/qbatchnorm_test.py +++ b/benchmarks/operator_benchmark/pt/qbatchnorm_test.py @@ -23,15 +23,17 @@ def init(self, M, N, K, device, dtype): self._init(M, N, K, device) x_scale = 0.1 x_zero_point = 0 - self.q_input_one = torch.quantize_per_tensor( - self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype) - self.mean = torch.rand(N) - self.var = torch.rand(N) - self.weight = torch.rand(N) - self.bias = torch.rand(N) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 + self.inputs = { + "q_input_one": torch.quantize_per_tensor( + self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype), + "mean": torch.rand(N), + "var": torch.rand(N), + "weight": torch.rand(N), + "bias": torch.rand(N), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } def _init(self, M, N, K, device): pass @@ -45,10 +47,20 @@ def _init(self, M, N, K, device): self.set_module_name("QBatchNorm1d") self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - def forward(self): + def forward( + self, + q_input_one, + weight, + bias, + mean, + var, + eps: float, + Y_scale: float, + Y_zero_point: int + ): return torch.ops.quantized.batch_norm1d( - self.q_input_one, self.weight, self.bias, self.mean, self.var, self.eps, - self.Y_scale, self.Y_zero_point) + q_input_one, weight, bias, mean, var, eps, + Y_scale, Y_zero_point) class QBatchNorm2dBenchmark(QBatchNormBenchmark): @@ -58,10 +70,20 @@ def _init(self, M, N, K, device): # add a 1 as the last dimension self.input_one = torch.rand(M, N, K, 1, device=device, requires_grad=self.auto_set()) - def forward(self): + def forward( + self, + q_input_one, + weight, + bias, + mean, + var, + eps: float, + Y_scale: float, + Y_zero_point: int + ): return torch.ops.quantized.batch_norm2d( - self.q_input_one, self.weight, self.bias, self.mean, self.var, self.eps, - self.Y_scale, self.Y_zero_point) + q_input_one, weight, bias, mean, var, eps, + Y_scale, Y_zero_point) op_bench.generate_pt_test(batchnorm_configs_short, QBatchNorm1dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qcat_test.py b/benchmarks/operator_benchmark/pt/qcat_test.py index 77a66f53d2f0..32dd32e43adf 100644 --- a/benchmarks/operator_benchmark/pt/qcat_test.py +++ b/benchmarks/operator_benchmark/pt/qcat_test.py @@ -2,6 +2,7 @@ import torch import torch.nn.quantized as nnq +from typing import List """Microbenchmarks for quantized Cat operator""" @@ -53,11 +54,14 @@ def init(self, M, N, K, L, dim, contig, dtype): elif contig == 'none': self.input = (q_input_non_contig, q_input_non_contig) - self.dim = dim + self.inputs = { + "input": self.input, + "dim": dim + } self.set_module_name('qcat') - def forward(self): - return self.qf.cat(self.input, dim=self.dim) + def forward(self, input: List[torch.Tensor], dim: int): + return self.qf.cat(input, dim=dim) op_bench.generate_pt_test(qcat_configs_short + qcat_configs_long, diff --git a/benchmarks/operator_benchmark/pt/qcomparators_test.py b/benchmarks/operator_benchmark/pt/qcomparators_test.py index d86ec20eb65d..9c26f6dee23b 100644 --- a/benchmarks/operator_benchmark/pt/qcomparators_test.py +++ b/benchmarks/operator_benchmark/pt/qcomparators_test.py @@ -34,23 +34,32 @@ def init(self, N, dtype, contig, other_scalar, out_variant, op_func): q_input_a = torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, dtype=dtype) - if other_scalar: - q_input_b = 42 - else: - q_input_b = q_input_a.clone() + q_input_b = q_input_a.clone() if not contig: permute_dims = list(range(f_input.ndim))[::-1] q_input_a = q_input_a.permute(permute_dims) self.qop = op_func - self.args = (q_input_a, q_input_b) - self.kwargs = {} + self.inputs = { + "q_input_a": q_input_a, + "q_input_b": q_input_b, + "out_variant": out_variant, + "other_scalar": other_scalar, + } + + def forward(self, q_input_a, q_input_b, out_variant: bool, other_scalar: bool): if out_variant: - self.kwargs['out'] = torch.tensor([], dtype=torch.bool) + if other_scalar: + return self.qop(q_input_a, 42, out=torch.tensor(True, dtype=torch.bool)) + else: + return self.qop(q_input_a, q_input_b, out=torch.tensor(True, dtype=torch.bool)) + else: + if other_scalar: + return self.qop(q_input_a, 42) + else: + return self.qop(q_input_a, q_input_b) - def forward(self): - return self.qop(*self.args, **self.kwargs) op_bench.generate_pt_tests_from_op_list(qcomparators_ops, diff --git a/benchmarks/operator_benchmark/pt/qconv_test.py b/benchmarks/operator_benchmark/pt/qconv_test.py index 24ca5ff9894e..14e8e143a7ca 100644 --- a/benchmarks/operator_benchmark/pt/qconv_test.py +++ b/benchmarks/operator_benchmark/pt/qconv_test.py @@ -24,16 +24,18 @@ def init(self, IC, OC, kernel, stride, N, L, device): W = torch.randn(OC, IC // G, kernel, dtype=torch.float32) self.qW = torch.quantize_per_tensor(W, scale=self.scale, zero_point=0, dtype=torch.qint8) - self.input = qX + self.inputs = { + "input": qX + } self.qconv1d = nnq.Conv1d(IC, OC, kernel, stride=stride, padding=pad, groups=G) self.qconv1d.set_weight_bias(self.qW, None) - self.qconv1d.scale = torch.tensor([self.scale], dtype=torch.double) - self.qconv1d.zero_point = torch.tensor([self.zero_point], dtype=torch.int) + self.qconv1d.scale = torch.tensor(self.scale, dtype=torch.double) + self.qconv1d.zero_point = torch.tensor(self.zero_point, dtype=torch.int) self.set_module_name("QConv1d") - def forward(self): - return self.qconv1d(self.input) + def forward(self, input): + return self.qconv1d(input) class QConv2dBenchmark(op_bench.TorchBenchmarkBase): @@ -51,16 +53,18 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): W = torch.randn(OC, IC // G, kernel, kernel, dtype=torch.float32) self.qW = torch.quantize_per_tensor(W, scale=self.scale, zero_point=0, dtype=torch.qint8) - self.input = qX + self.inputs = { + "input": qX + } self.qconv2d = nnq.Conv2d(IC, OC, kernel, stride=stride, padding=pad, groups=G) self.qconv2d.set_weight_bias(self.qW, None) - self.qconv2d.scale = torch.tensor([self.scale], dtype=torch.double) - self.qconv2d.zero_point = torch.tensor([self.zero_point], dtype=torch.int) + self.qconv2d.scale = torch.tensor(self.scale, dtype=torch.double) + self.qconv2d.zero_point = torch.tensor(self.zero_point, dtype=torch.int) self.set_module_name("QConv2d") - def forward(self): - return self.qconv2d(self.input) + def forward(self, input): + return self.qconv2d(input) op_bench.generate_pt_test(configs.remove_cuda(configs.conv_1d_configs_short + configs.conv_1d_configs_long), QConv1dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py index 5281c43a0d80..4bd06b027969 100644 --- a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py +++ b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py @@ -2,7 +2,7 @@ import operator_benchmark as op_bench import torch import numpy as np - +from typing import Optional from torch.testing._internal.common_quantization import ( lengths_to_offsets @@ -20,7 +20,7 @@ is_pruned_weights=(True, False,), use_32bit_indices=(True, False), use_32bit_offsets=(True, False), - tags=('short',), + tags=['short'], ) @@ -33,11 +33,11 @@ is_pruned_weights=(True, False,), use_32bit_indices=(True, False), use_32bit_offsets=(True, False), - tags=('long',) + tags=['long'] ) -full_configs = embedding_bag_rowwise_offsets_long_configs + embedding_bag_rowwise_offsets_short_configs +full_configs = embedding_bag_rowwise_offsets_short_configs + embedding_bag_rowwise_offsets_long_configs four_bit_rowwise_ops = op_bench.op_list( attrs=( @@ -117,14 +117,37 @@ def init(self, if self.is_pruned_weights: self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights) + self.inputs = { + "prepacked_weights": self.prepacked_weights, + "indices": self.indices, + "offsets": self.offsets, + "mode": 0, + "per_sample_weights": self.per_sample_weights, + "include_last_offset": self.include_last_offset, + "is_pruned_weights": self.is_pruned_weights, + "compressed_indices": self.compressed_indices + } + self.op_func = op_func - def forward(self): - return self.op_func(self.prepacked_weights, self.indices, self.offsets, - mode=0, per_sample_weights=self.per_sample_weights, - include_last_offset=self.include_last_offset, - pruned_weights=self.is_pruned_weights, - compressed_indices_mapping=self.compressed_indices) + def forward( + self, + prepacked_weights, + indices, + offsets, + mode: int, + per_sample_weights: Optional[torch.Tensor], + include_last_offset: bool, + is_pruned_weights: bool, + compressed_indices: Optional[torch.Tensor] + ): + + return self.op_func(prepacked_weights, indices, offsets, + mode=mode, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + pruned_weights=is_pruned_weights, + compressed_indices_mapping=compressed_indices) class EmbedddingBagByteRowwiseOffsetsTest(op_bench.TorchBenchmarkBase): @@ -181,11 +204,33 @@ def init(self, if self.is_pruned_weights: self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights) + self.inputs = { + "prepacked_weights": self.prepacked_weights, + "indices": self.indices, + "offsets": self.offsets, + "mode": 0, + "per_sample_weights": self.per_sample_weights, + "include_last_offset": self.include_last_offset, + "is_pruned_weights": self.is_pruned_weights, + "compressed_indices": self.compressed_indices + } + self.op_func = op_func - def forward(self): - return self.op_func(self.prepacked_weights, self.indices, self.offsets, - mode=0, per_sample_weights=self.per_sample_weights, + def forward( + self, + prepacked_weights, + indices, + offsets, + mode: int, + per_sample_weights: Optional[torch.Tensor], + include_last_offset: bool, + is_pruned_weights: bool, + compressed_indices: Optional[torch.Tensor] + ): + return self.op_func(prepacked_weights, indices, offsets, + mode=0, + per_sample_weights=per_sample_weights, include_last_offset=self.include_last_offset, pruned_weights=self.is_pruned_weights, compressed_indices_mapping=self.compressed_indices) diff --git a/benchmarks/operator_benchmark/pt/qembedding_pack_test.py b/benchmarks/operator_benchmark/pt/qembedding_pack_test.py index e64d4fa1962b..f9a3aaff051a 100644 --- a/benchmarks/operator_benchmark/pt/qembedding_pack_test.py +++ b/benchmarks/operator_benchmark/pt/qembedding_pack_test.py @@ -35,21 +35,25 @@ class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase): def init(self, num_embeddings, embedding_dim, op_func): - self.weight = torch.from_numpy((np.random.random_sample(( - num_embeddings, embedding_dim)) + 1).astype(np.float32)) + self.inputs = { + "weight": torch.from_numpy((np.random.random_sample(( + num_embeddings, embedding_dim)) + 1).astype(np.float32)) + } self.op_func = op_func - def forward(self): - return self.op_func(self.weight) + def forward(self, weight): + return self.op_func(weight) class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase): def init(self, num_embeddings, embedding_dim, op_func): weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float) - self.packed_weight = weight.to(torch.uint8) + self.inputs = { + "packed_weight": weight.to(torch.uint8) + } self.op_func = op_func - def forward(self): - return self.op_func(self.packed_weight) + def forward(self, packed_weight): + return self.op_func(packed_weight) op_bench.generate_pt_tests_from_op_list(conversion_ops, diff --git a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py index f145edf4f485..872f8c28fccd 100644 --- a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py @@ -20,10 +20,14 @@ def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_las self.input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() offset = torch.LongTensor([offset], device=device) self.offset = torch.cat((offset, torch.tensor([self.input.size(0)], dtype=torch.long)), 0) + self.inputs = { + "input": self.input, + "offset": self.offset + } self.set_module_name('qEmbeddingBag') - def forward(self): - return self.embedding(self.input, self.offset) + def forward(self, input, offset): + return self.embedding(input, offset) op_bench.generate_pt_test(configs.embeddingbag_short_configs, QEmbeddingBagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qgroupnorm_test.py b/benchmarks/operator_benchmark/pt/qgroupnorm_test.py index 6881bc4c518d..942d6ab6560c 100644 --- a/benchmarks/operator_benchmark/pt/qgroupnorm_test.py +++ b/benchmarks/operator_benchmark/pt/qgroupnorm_test.py @@ -20,23 +20,26 @@ class QGroupNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims, num_groups, dtype): X = (torch.rand(*dims) - 0.5) * 256 - self.num_groups = num_groups num_channels = dims[1] scale = 1.0 zero_point = 0 - self.qX = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - - def forward(self): + + self.inputs = { + "qX": torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=dtype), + "num_groups": num_groups, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, num_groups: int, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.group_norm( - self.qX, self.num_groups, weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, num_groups, weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(groupnorm_configs_short, QGroupNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qinstancenorm_test.py b/benchmarks/operator_benchmark/pt/qinstancenorm_test.py index 5a770728bb5e..df084700fac0 100644 --- a/benchmarks/operator_benchmark/pt/qinstancenorm_test.py +++ b/benchmarks/operator_benchmark/pt/qinstancenorm_test.py @@ -22,19 +22,22 @@ def init(self, dims, dtype): num_channels = dims[1] scale = 1.0 zero_point = 0 - self.qX = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - - def forward(self): + + self.inputs = { + "qX": torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=dtype), + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.instance_norm( - self.qX, weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(instancenorm_configs_short, QInstanceNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qinterpolate_test.py b/benchmarks/operator_benchmark/pt/qinterpolate_test.py index a1861f4fe4b9..753154f13598 100644 --- a/benchmarks/operator_benchmark/pt/qinterpolate_test.py +++ b/benchmarks/operator_benchmark/pt/qinterpolate_test.py @@ -44,15 +44,18 @@ def init(self, M, N, K, dtype, mode, scale, contig): dtype=dtype) if not contig: permute_dims = list(range(q_input.ndim))[::-1] - self.q_input_a = self.q_input_a.permute(permute_dims) + self.q_input = self.q_input.permute(permute_dims) - self.mode = mode - self.scale_factor = scale + self.inputs = { + "q_input": self.q_input, + "scale_factor": scale, + "mode": mode + } self.set_module_name('q_interpolate') - def forward(self): - return torch.nn.quantized.functional.interpolate( - self.q_input, scale_factor=self.scale_factor, mode=self.mode) + def forward(self, q_input, scale_factor: float, mode: str): + return torch.nn.functional.interpolate( + q_input, scale_factor=scale_factor, mode=mode) op_bench.generate_pt_test(qinterpolate_short_configs + qinterpolate_long_configs, diff --git a/benchmarks/operator_benchmark/pt/qlayernorm_test.py b/benchmarks/operator_benchmark/pt/qlayernorm_test.py index ee3224c31515..0a145ee015ea 100644 --- a/benchmarks/operator_benchmark/pt/qlayernorm_test.py +++ b/benchmarks/operator_benchmark/pt/qlayernorm_test.py @@ -25,17 +25,21 @@ def init(self, dims, dtype): zero_point = 0 self.qX = torch.quantize_per_tensor( X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(*self.qX.size()[1:], dtype=torch.float) - self.bias = torch.rand(*self.qX.size()[1:], dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - def forward(self): + self.inputs = { + "qX": self.qX, + "weight": torch.rand(*self.qX.size()[1:], dtype=torch.float), + "bias": torch.rand(*self.qX.size()[1:], dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.layer_norm( - self.qX, self.qX.size()[1:], weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, qX.size()[1:], weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qlinear_test.py b/benchmarks/operator_benchmark/pt/qlinear_test.py index 00b477067876..6e4dd9d97eca 100644 --- a/benchmarks/operator_benchmark/pt/qlinear_test.py +++ b/benchmarks/operator_benchmark/pt/qlinear_test.py @@ -26,21 +26,25 @@ def init(self, N, IN, OUT, linear_under_test): self.qlinear.scale = scale self.qlinear.zero_point = zero_point - def forward(self): + def forward(self, input): # Assume that the `self.input` is set in the child - return self.qlinear(self.input) + return self.qlinear(input) class QLinearBenchmark(_QLinearBenchmarkBase): def init(self, N, IN, OUT, device): super(QLinearBenchmark, self).init(N, IN, OUT, nnq.Linear(IN, OUT)) - self.input = self.qX + self.inputs = { + "input": self.qX + } self.set_module_name("QLinear") class QDynamicLinearBenchmark(_QLinearBenchmarkBase): def init(self, N, IN, OUT, device): super(QDynamicLinearBenchmark, self).init(N, IN, OUT, nnqd.Linear(IN, OUT)) - self.input = self.X + self.inputs = { + "input": self.X + } self.set_module_name("QDynamicLinear") diff --git a/benchmarks/operator_benchmark/pt/qobserver_test.py b/benchmarks/operator_benchmark/pt/qobserver_test.py index 149acd260565..6521773a73ff 100644 --- a/benchmarks/operator_benchmark/pt/qobserver_test.py +++ b/benchmarks/operator_benchmark/pt/qobserver_test.py @@ -104,19 +104,22 @@ class QObserverBenchmark(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): - self.f_input = torch.rand(C, M, N, device=device) + self.inputs = { + "f_input": torch.rand(C, M, N, device=device) + } self.op_func = op_func(dtype=dtype, qscheme=qscheme).to(device) - def forward(self): - self.op_func(self.f_input) - self.op_func.calculate_qparams() - return + def forward(self, f_input): + self.op_func(f_input) + return self.op_func.calculate_qparams() + class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): self.f_input = torch.rand(C, M, N, device=device) self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device) self.q_observer(self.f_input) + self.inputs = {} def forward(self): return self.q_observer.calculate_qparams() diff --git a/benchmarks/operator_benchmark/pt/qpool_test.py b/benchmarks/operator_benchmark/pt/qpool_test.py index d53b4d05db98..b5c40fd4977a 100644 --- a/benchmarks/operator_benchmark/pt/qpool_test.py +++ b/benchmarks/operator_benchmark/pt/qpool_test.py @@ -88,8 +88,12 @@ def setup(self, N, C, H, W, dtype, contig): self.q_input = self.q_input.permute(0, 2, 3, 1).contiguous() self.q_input = self.q_input.permute(0, 3, 1, 2) - def forward(self): - return self.pool_op(self.q_input) + self.inputs = { + "q_input": self.q_input + } + + def forward(self, q_input): + return self.pool_op(q_input) class QMaxPool2dBenchmark(_QPool2dBenchmarkBase): diff --git a/benchmarks/operator_benchmark/pt/qrnn_test.py b/benchmarks/operator_benchmark/pt/qrnn_test.py index 187a8f1a82e0..c6d696b81794 100644 --- a/benchmarks/operator_benchmark/pt/qrnn_test.py +++ b/benchmarks/operator_benchmark/pt/qrnn_test.py @@ -45,20 +45,25 @@ def init(self, I, H, NL, B, D, dtype): {nn.LSTM, nn.Linear}, dtype=dtype)[0] - self.x = torch.randn(sequence_len, # sequence length - batch_size, # batch size - I) # Number of features in X - self.h = torch.randn(NL * (D + 1), # layer_num * dir_num - batch_size, # batch size - H) # hidden size - self.c = torch.randn(NL * (D + 1), # layer_num * dir_num - batch_size, # batch size - H) # hidden size + x = torch.randn(sequence_len, # sequence length + batch_size, # batch size + I) # Number of features in X + h = torch.randn(NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H) # hidden size + c = torch.randn(NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H) # hidden size + self.inputs = { + "x": x, + "h": h, + "c": c + } self.set_module_name("QLSTM") - def forward(self): - return self.cell(self.x, (self.h, self.c)) + def forward(self, x, h, c): + return self.cell(x, (h, c))[0] op_bench.generate_pt_test(qrnn_configs, LSTMBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qtensor_method_test.py b/benchmarks/operator_benchmark/pt/qtensor_method_test.py index 4834e3cb166b..50dc59780ab8 100644 --- a/benchmarks/operator_benchmark/pt/qtensor_method_test.py +++ b/benchmarks/operator_benchmark/pt/qtensor_method_test.py @@ -22,16 +22,9 @@ tags=['long'] ) -qmethods_tensor_input_list = op_bench.op_list( - attr_names=['op_name', 'op_func'], - attrs=[ - ['q_copy', 'copy_'], - ], -) - class _QMethodBenchmarkBase(op_bench.TorchBenchmarkBase): - def init(self, M, N, dtype, contig, op_func): + def init(self, M, N, dtype, contig): f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 @@ -41,23 +34,20 @@ def init(self, M, N, dtype, contig, op_func): if not contig: permute_dims = list(range(self.q_input.ndim))[::-1] self.q_input = self.q_input.permute(permute_dims) - self.op_func = op_func - -class QMethodTensorInputBenchmark(_QMethodBenchmarkBase): - def forward(self): - getattr(self.q_input, self.op_func)(self.q_input) + self.inputs = { + "q_input": self.q_input, + } -class QMethodNoInputBenchmark(_QMethodBenchmarkBase): - def forward(self): - getattr(self.q_input, self.op_func)() +class QMethodTensorInputCopyBenchmark(_QMethodBenchmarkBase): + def forward(self, q_input): + return q_input.copy_(q_input) -op_bench.generate_pt_tests_from_op_list( - qmethods_tensor_input_list, +op_bench.generate_pt_test( qmethods_configs_short + qmethods_configs_long, - QMethodTensorInputBenchmark + QMethodTensorInputCopyBenchmark ) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/quantization_test.py b/benchmarks/operator_benchmark/pt/quantization_test.py index 4a83a2d1b75c..af09a5fa2523 100644 --- a/benchmarks/operator_benchmark/pt/quantization_test.py +++ b/benchmarks/operator_benchmark/pt/quantization_test.py @@ -56,8 +56,12 @@ def init(self, C, M, N, dtype, mode): self.op = nnq.DeQuantize() self.set_module_name('DequantizePerTensor') - def forward(self): - return self.op(self.input) + self.inputs = { + "input": self.input + } + + def forward(self, input): + return self.op(input) op_bench.generate_pt_test( @@ -98,12 +102,22 @@ def init(self, C, M, N, dtype, axis, mode): if mode == 'D': self.input = self.op(self.input, **self.kwargs) - # Dequantize doesn't take any arguments - self.op = lambda x, **kwargs: x.dequantize() + + def dequant(input, scales, zero_points, axis: int, dtype: int): + return input.dequantize() + self.op = dequant self.set_module_name('DequantizePerChannel') - def forward(self): - return self.op(self.input, **self.kwargs) + self.inputs = { + "input": self.input, + 'scales': torch.tensor([1.0] * channel_len), + 'zero_points': torch.tensor([0] * channel_len), + 'axis': axis, + 'dtype': dtype + } + + def forward(self, input, scales, zero_points, axis: int, dtype: int): + return self.op(input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype) op_bench.generate_pt_test( @@ -141,12 +155,14 @@ def forward(self): class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks fake quantization with default parameters.""" def init(self, N, C, H, W): - self.input = torch.rand(N, C, H, W) + self.inputs = { + "input": torch.rand(N, C, H, W) + } self.op = tq.FakeQuantize() self.set_module_name('FakeQuantize') - def forward(self): - return self.op(self.input) + def forward(self, input): + return self.op(input) op_bench.generate_pt_test( @@ -160,11 +176,37 @@ def forward(self): # scale and zero point. # original_kernel represents the original fake quantize c++ kernel. +def fakeQuantizePerTensorPyModule( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return _LearnableFakeQuantizePerTensorOp.apply(input, scale, zero_point, quant_min, quant_max, 1.0) + +def fakeQuantizePerTensorLearnableKernel( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return torch._fake_quantize_learnable_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) + +def fakeQuantizePerTensorOriginalKernel( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max) + +fake_quantize_per_tensor_ops = op_bench.op_list( + attrs=( + ('py_module', fakeQuantizePerTensorPyModule), + ('learnable_kernel', fakeQuantizePerTensorLearnableKernel), + ('original_kernel', fakeQuantizePerTensorOriginalKernel) + ), + attr_names=('op_name', 'op_func'), +) + fake_quantize_operator_configs_short = op_bench.config_list( cross_product_configs={ 'nbits': (4, 8), 'device': ('cpu', 'cuda'), - 'op_type': ('py_module', 'learnable_kernel', 'original_kernel') }, **fake_quantize_configs_short_dict ) @@ -172,87 +214,114 @@ def forward(self): fake_quantize_operator_configs_long = op_bench.cross_product_configs( nbits=(4, 8), device=('cpu', 'cuda'), - op_type=('py_module', 'learnable_kernel', 'original_kernel'), **fake_quantize_configs_long_dict ) -class FakeQuantizePerTensorOpBenchmark(op_bench.TorchBenchmarkBase): +class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks 3 different fake quantize per tensor operators.""" - def init(self, N, C, H, W, nbits, device, op_type): + def init(self, N, C, H, W, nbits, device, op_func): self.quant_min = 0 self.quant_max = 2 ** nbits - 1 self.quant_range = 2 ** nbits - self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device) - self.scale = torch.tensor([1.]).to(device) - self.zero_point = torch.tensor([0.]).to(device) - self.input.requires_grad_() - self.scale.requires_grad_() - self.zero_point.requires_grad_() - self.args = [ - self.input, self.scale, self.zero_point, - self.quant_min, self.quant_max - ] - if op_type == 'py_module': - self.op = _LearnableFakeQuantizePerTensorOp.apply - self.args.append(1.) - elif op_type == 'learnable_kernel': - self.op = torch._fake_quantize_learnable_per_tensor_affine - else: - # Replace tensors with float and long types for original per tensor - # fake quantize kernel. - self.args[1], self.args[2] = 1., 0 - self.op = torch.fake_quantize_per_tensor_affine + self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device, requires_grad=self.auto_set()) + self.scale = torch.tensor([1.], requires_grad=self.auto_set()).to(device) + self.zero_point = torch.tensor([0.], requires_grad=self.auto_set()).to(device) + + self.inputs = { + "input": self.input, + "scale": self.scale, + "zero_point": self.zero_point, + "quant_min": self.quant_min, + "quant_max": self.quant_max, + } + self.op_func = op_func - def forward(self): - return self.op(*self.args) + def forward( + self, input, scale, zero_point, + quant_min: int, quant_max: int + ): + return self.op_func(input, scale, zero_point, quant_min, quant_max) -op_bench.generate_pt_test( +op_bench.generate_pt_tests_from_op_list( + fake_quantize_per_tensor_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, - FakeQuantizePerTensorOpBenchmark + FakeQuantizePerTensorBaseOpBenchmark ) - -op_bench.generate_pt_gradient_test( +op_bench.generate_pt_gradient_tests_from_op_list( + fake_quantize_per_tensor_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, - FakeQuantizePerTensorOpBenchmark + FakeQuantizePerTensorBaseOpBenchmark +) + +def fakeQuantizePerChannelPyModule( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return _LearnableFakeQuantizePerChannelOp.apply(input, scale, zero_point, axis, quant_min, quant_max, 1.0) + +def fakeQuantizePerChannelLearnableKernel( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return torch._fake_quantize_learnable_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) + +def fakeQuantizePerChannelOriginalKernel( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return torch.fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) + +fake_quantize_per_channel_ops = op_bench.op_list( + attrs=( + ('py_module', fakeQuantizePerChannelPyModule), + ('learnable_kernel', fakeQuantizePerChannelLearnableKernel), + ('original_kernel', fakeQuantizePerChannelOriginalKernel) + ), + attr_names=('op_name', 'op_func'), ) class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks 3 different fake quantize per channel operators.""" - def init(self, N, C, H, W, nbits, device, op_type): + def init(self, N, C, H, W, nbits, device, op_func): self.quant_min = 0 self.quant_max = 2 ** nbits - 1 self.quant_range = 2 ** nbits # Axis is chosen with respect to the number of channels: C. self.axis = 1 - self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device) - self.scale = torch.ones(C, device=device, dtype=torch.float32) - self.zero_point = torch.zeros(C, device=device, dtype=torch.float32) - self.input.requires_grad_() - self.scale.requires_grad_() - self.zero_point.requires_grad_() - self.args = [ - self.input, self.scale, self.zero_point, - self.axis, self.quant_min, self.quant_max - ] - if op_type == 'py_module': - self.op = _LearnableFakeQuantizePerChannelOp.apply - self.args.append(1.) - elif op_type == 'learnable_kernel': - self.op = torch._fake_quantize_learnable_per_channel_affine + self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device, requires_grad=self.auto_set()) + + if op_func.__name__ == 'fakeQuantizePerChannelOriginalKernel': + self.scale = torch.ones(C, device=device, dtype=torch.float32, requires_grad=False) + self.zero_point = torch.zeros(C, device=device, dtype=torch.int64, requires_grad=False) else: - self.args[1] = torch.ones(C, device=device, dtype=torch.float32) - self.args[2] = torch.zeros(C, device=device, dtype=torch.int64) - self.op = torch.fake_quantize_per_channel_affine + self.scale = torch.ones(C, device=device, dtype=torch.float32, requires_grad=self.auto_set()) + self.zero_point = torch.zeros(C, device=device, dtype=torch.float32, requires_grad=self.auto_set()) + + self.inputs = { + "input": self.input, + "scale": self.scale, + "zero_point": self.zero_point, + "axis": self.axis, + "quant_min": self.quant_min, + "quant_max": self.quant_max, + } - def forward(self): - return self.op(*self.args) + self.op_func = op_func -op_bench.generate_pt_test( + def forward( + self, input, scale, zero_point, + axis: int, quant_min: int, quant_max: int + ): + return self.op_func(input, scale, zero_point, axis, quant_min, quant_max) + +op_bench.generate_pt_tests_from_op_list( + fake_quantize_per_channel_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, FakeQuantizePerChannelOpBenchmark ) -op_bench.generate_pt_gradient_test( +op_bench.generate_pt_gradient_tests_from_op_list( + fake_quantize_per_channel_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, FakeQuantizePerChannelOpBenchmark ) diff --git a/benchmarks/operator_benchmark/pt/qunary_test.py b/benchmarks/operator_benchmark/pt/qunary_test.py index 4b800857caff..2b3cb34ab30c 100644 --- a/benchmarks/operator_benchmark/pt/qunary_test.py +++ b/benchmarks/operator_benchmark/pt/qunary_test.py @@ -30,13 +30,15 @@ def init(self, M, N, dtype, op_func): f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 - self.q_input = torch.quantize_per_tensor(f_input, scale=scale, + self.inputs = { + "q_input": torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, dtype=dtype) + } self.op_func = op_func - def forward(self): - return self.op_func(self.q_input) + def forward(self, q_input): + return self.op_func(q_input) # TODO: Uncomment the ops whenever they are implemented for quantized tensor. @@ -153,17 +155,19 @@ def forward(self): class QTopkOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, dtype, k): - self.k = k f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 - self.q_input = torch.quantize_per_tensor(f_input, scale=scale, + self.inputs = { + "q_input": torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, - dtype=dtype) + dtype=dtype), + "k": k + } self.set_module_name('qtopk') - def forward(self): - return torch.topk(self.q_input, self.k) + def forward(self, q_input, k: int): + return torch.topk(q_input, k) op_bench.generate_pt_test(qunary_ops_topk_configs_short + qunary_ops_topk_configs_long, QTopkOpBenchmark) diff --git a/benchmarks/operator_benchmark/pt/remainder_test.py b/benchmarks/operator_benchmark/pt/remainder_test.py index ffb38f785b55..1aa7770d63e1 100644 --- a/benchmarks/operator_benchmark/pt/remainder_test.py +++ b/benchmarks/operator_benchmark/pt/remainder_test.py @@ -47,10 +47,15 @@ def init(self, M, N, K, device, dtype, op_func): # +1 so we don't divide by zero self.divisor = (self.divisor * 40 + 1).to(dtype=dtype) + self.inputs = { + "dividend": self.dividend, + "divisor": self.divisor + } + self.op_func = op_func - def forward(self): - return self.op_func(self.dividend, self.divisor) + def forward(self, dividend, divisor): + return self.op_func(dividend, divisor) op_bench.generate_pt_tests_from_op_list(remainder_ops_list, diff --git a/benchmarks/operator_benchmark/pt/softmax_test.py b/benchmarks/operator_benchmark/pt/softmax_test.py index 65446c5c30ee..237d9001e017 100644 --- a/benchmarks/operator_benchmark/pt/softmax_test.py +++ b/benchmarks/operator_benchmark/pt/softmax_test.py @@ -47,11 +47,13 @@ class SoftmaxBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(softmax_ops_list, diff --git a/benchmarks/operator_benchmark/pt/split_test.py b/benchmarks/operator_benchmark/pt/split_test.py index f4da9437351e..2972db5d2d1b 100644 --- a/benchmarks/operator_benchmark/pt/split_test.py +++ b/benchmarks/operator_benchmark/pt/split_test.py @@ -30,12 +30,14 @@ class SplitBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, parts, device): - self.input_one = torch.rand(M, N, device=device) - self.split_size = int(M * N / parts) + self.inputs = { + "input": torch.rand(M, N, device=device), + "split_size": int(M * N / parts) + } self.set_module_name('split') - def forward(self): - return torch.split(self.input_one, self.split_size) + def forward(self, input, split_size: int): + return torch.split(input, split_size) op_bench.generate_pt_test(split_configs_short + split_configs_long, diff --git a/benchmarks/operator_benchmark/pt/sum_test.py b/benchmarks/operator_benchmark/pt/sum_test.py index 6b7fef83469e..799267dfc7de 100644 --- a/benchmarks/operator_benchmark/pt/sum_test.py +++ b/benchmarks/operator_benchmark/pt/sum_test.py @@ -33,11 +33,14 @@ def init(self, R, V, dim, contiguous, device): else: self.input_tensor = tensor - self.dim = dim + self.inputs = { + "input_tensor": self.input_tensor, + "dim": dim + } self.set_module_name("sum") - def forward(self): - return self.input_tensor.sum(dim=self.dim) + def forward(self, input_tensor, dim: int): + return input_tensor.sum(dim=dim) op_bench.generate_pt_test(sum_configs, SumBenchmark) diff --git a/benchmarks/operator_benchmark/pt/tensor_to_test.py b/benchmarks/operator_benchmark/pt/tensor_to_test.py index 7f4c440c2c39..0afaa3191d4e 100644 --- a/benchmarks/operator_benchmark/pt/tensor_to_test.py +++ b/benchmarks/operator_benchmark/pt/tensor_to_test.py @@ -17,17 +17,21 @@ class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device): - self.input = torch.rand(M, N, device=device, requires_grad=False, dtype=torch.float) + self.inputs = { + "input": torch.rand(M, N, device=device, requires_grad=False, dtype=torch.float) + } - def forward(self): - return self.input.to(torch.half) + def forward(self, input): + return input.to(torch.half) class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device): - self.input = torch.rand(M, N, device=device, requires_grad=False, dtype=torch.half) + self.inputs = { + "input": torch.rand(M, N, device=device, requires_grad=False, dtype=torch.half) + } - def forward(self): - return self.input.to(torch.float) + def forward(self, input): + return input.to(torch.float) op_bench.generate_pt_test(tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark) diff --git a/benchmarks/operator_benchmark/pt/unary_test.py b/benchmarks/operator_benchmark/pt/unary_test.py index 1391283b1e10..7fd465d6525d 100644 --- a/benchmarks/operator_benchmark/pt/unary_test.py +++ b/benchmarks/operator_benchmark/pt/unary_test.py @@ -27,12 +27,43 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device, op_func): - self.input_one = torch.rand(M, N, device=device) + self.inputs = { + "input": torch.rand(M, N, device=device) + } self.op_func = op_func - def forward(self): - return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) +def bernoulli_(input): + return input.bernoulli_() + +def cauchy_(input): + return input.cauchy_() + +def digamma_(input): + return input.digamma_() + +def exponential_(input): + return input.exponential_() + +def normal_(input): + return input.normal_() + +def random_(input): + return input.random_() + +def sign_(input): + return input.sign_() + +def uniform_(input): + return input.uniform_() + +def half_(input): + return input.half() + +def long_(input): + return input.long() unary_ops_list = op_bench.op_list( attr_names=['op_name', 'op_func'], @@ -105,18 +136,18 @@ def forward(self): ['tanh_', torch.tanh_], ['trunc', torch.trunc], ['trunc_', torch.trunc_], - ['unique', torch.unique], + ['unique', torch.functional._return_output], ['zero_', torch.zero_], - ['bernoulli_', lambda t: t.bernoulli_()], - ['cauchy_', lambda t: t.cauchy_()], - ['digamma_', lambda t: t.digamma_()], - ['exponential_', lambda t: t.exponential_()], - ['normal_', lambda t: t.normal_()], - ['random_', lambda t: t.random_()], - ['sign_', lambda t: t.sign_()], - ['uniform_', lambda t: t.uniform_()], - ['half', lambda t: t.half()], - ['long', lambda t: t.long()], + ['bernoulli_', bernoulli_], + ['cauchy_', cauchy_], + ['digamma_', digamma_], + ['exponential_', exponential_], + ['normal_', normal_], + ['random_', random_], + ['sign_', sign_], + ['uniform_', uniform_], + ['half', half_], + ['long', long_], ], ) diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index ed9fd5a08ccc..7364b3166062 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -394,7 +394,7 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices) -def celu(input: Tensor, scale: float, zero_point: int, alpha: Optional[float] = 1.) -> Tensor: +def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor: r"""celu(input, scale, zero_point, alpha=1.) -> Tensor Applies the quantized CELU function element-wise. @@ -411,7 +411,7 @@ def celu(input: Tensor, scale: float, zero_point: int, alpha: Optional[float] = def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False, - scale: float = None, zero_point: int = None): + scale: Optional[float] = None, zero_point: Optional[int] = None): r""" Quantized version of the. leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor From a376d3dd5d360f8df9e09ae10677628778b4eb2a Mon Sep 17 00:00:00 2001 From: Jiakai Liu Date: Thu, 12 Nov 2020 19:14:33 -0800 Subject: [PATCH 84/93] [pytorch] strip out warning message ifdef STRIP_ERROR_MESSAGES (#47827) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47827 Similar to TORCH_CHECK_WITH_MSG, strip messages for TORCH_WARN/TORCH_WARN_ONCE. Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D24913586 Pulled By: ljk53 fbshipit-source-id: 00f0f2bf33a48d5d7008b70ff5820623586dfd4e --- c10/util/Exception.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 3a80cd1d3fb4..fed17a4cf526 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -387,18 +387,30 @@ inline std::string if_empty_then(std::string x, std::string y) { // Report a warning to the user. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_WARN(...) \ + ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, {}, false) +#else #define TORCH_WARN(...) \ ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, ::c10::str(__VA_ARGS__), false) +#endif // Report a warning to the user only once. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_WARN_ONCE(...) \ + C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \ + ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, {}, false); \ + return true; \ + }() +#else #define TORCH_WARN_ONCE(...) \ C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \ ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, ::c10::str(__VA_ARGS__), false); \ return true; \ }() - +#endif // ---------------------------------------------------------------------------- // Deprecated macros From 4f538a2ba48afeb2a2a1f3b6e01b1ec461d4a5ed Mon Sep 17 00:00:00 2001 From: Jiakai Liu Date: Thu, 12 Nov 2020 19:14:33 -0800 Subject: [PATCH 85/93] [pytorch][bot] update mobile op deps (#47825) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47825 Test Plan: Imported from OSS Reviewed By: iseeyuan Differential Revision: D24913587 Pulled By: ljk53 fbshipit-source-id: b6219573c3238fb453d88019197a00c9f9dbabb8 --- tools/code_analyzer/default_op_deps.yaml | 1337 +++++++++++++++------- 1 file changed, 915 insertions(+), 422 deletions(-) diff --git a/tools/code_analyzer/default_op_deps.yaml b/tools/code_analyzer/default_op_deps.yaml index c2adb0dbb807..bba322fec8e6 100644 --- a/tools/code_analyzer/default_op_deps.yaml +++ b/tools/code_analyzer/default_op_deps.yaml @@ -1,20 +1,27 @@ - name: __ROOT__ depends: + - name: aten::_coalesced_ - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized - name: aten::_indices + - name: aten::_mkldnn_transpose - name: aten::_sparse_coo_tensor_unsafe - name: aten::_values - name: aten::_version - name: aten::add - name: aten::add_ + - name: aten::addmm_ - name: aten::any + - name: aten::as_strided - name: aten::as_strided_ - name: aten::cat - name: aten::chunk + - name: aten::clamp_max + - name: aten::clamp_min - name: aten::clone - name: aten::contiguous - name: aten::copy_ + - name: aten::dense_dim - name: aten::dequantize - name: aten::detach - name: aten::empty @@ -23,7 +30,6 @@ - name: aten::eq - name: aten::equal - name: aten::expand - - name: aten::fill_ - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_leaf @@ -33,8 +39,8 @@ - name: aten::lt - name: aten::mm - name: aten::mul + - name: aten::mul_ - name: aten::narrow - - name: aten::ones - name: aten::ones_like - name: aten::output_nr - name: aten::q_per_channel_axis @@ -52,19 +58,36 @@ - name: aten::set_ - name: aten::set_data - name: aten::size + - name: aten::sparse_dim - name: aten::stride - name: aten::sub - name: aten::sum - name: aten::t - name: aten::to + - name: aten::transpose + - name: aten::transpose_ - name: aten::view - name: aten::zero_ - name: aten::zeros - name: aten::zeros_like - name: _quantized::add depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to - name: _quantized::conv2d depends: - name: aten::eq @@ -98,10 +121,53 @@ - name: aten::is_nonzero - name: aten::squeeze_ - name: aten::unsqueeze +- name: _quantized::conv_transpose1d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros - name: _quantized::conv_transpose2d depends: - name: aten::eq - name: aten::is_nonzero +- name: _quantized::conv_transpose2d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: _quantized::conv_transpose3d_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero - name: _quantized::linear depends: - name: aten::eq @@ -112,8 +178,28 @@ - name: aten::is_nonzero - name: _quantized::linear_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: _quantized::linear_prepack_fp16 depends: - name: aten::eq @@ -140,11 +226,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::leaky_relu -- name: aten::Int - depends: - - name: aten::eq - - name: aten::is_nonzero - - name: aten::item - name: aten::__and__ depends: - name: aten::bitwise_and @@ -163,11 +244,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -185,11 +262,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -207,11 +280,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -229,11 +298,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -266,57 +331,37 @@ depends: - name: aten::eq - name: aten::is_nonzero -- name: aten::_addmv_impl_ - depends: - - name: aten::contiguous - - name: aten::eq - - name: aten::is_nonzero - - name: aten::size - - name: aten::stride -- name: aten::_addr +- name: aten::_add_relu depends: - - name: aten::_copy_from - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - name: aten::resize_as_ - - name: aten::set_quantizer_ - - name: aten::size - - name: aten::stride - name: aten::to - - name: aten::zero_ -- name: aten::_addr_ +- name: aten::_add_relu_ depends: - - name: aten::_copy_from - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - name: aten::resize_as_ - - name: aten::set_quantizer_ + - name: aten::to +- name: aten::_addmv_impl_ + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero - name: aten::size - name: aten::stride - - name: aten::to - - name: aten::zero_ - name: aten::_aminmax depends: - name: aten::as_strided_ @@ -336,7 +381,7 @@ - name: aten::stride - name: aten::to - name: aten::unsqueeze_ -- name: aten::_amp_non_finite_check_and_unscale_ +- name: aten::_amp_foreach_non_finite_check_and_unscale_ depends: - name: aten::eq - name: aten::is_nonzero @@ -344,6 +389,10 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::_backward + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::_baddbmm_mkl_ depends: - name: aten::eq @@ -690,9 +739,16 @@ - name: aten::size - name: aten::_dirichlet_grad depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::_embedding_bag depends: - name: aten::contiguous @@ -896,6 +952,7 @@ - name: aten::unsqueeze - name: aten::_fft_with_size depends: + - name: aten::_fft_with_size - name: aten::eq - name: aten::is_nonzero - name: aten::_foreach_add @@ -948,6 +1005,16 @@ - name: aten::eq - name: aten::exp_ - name: aten::is_nonzero +- name: aten::_foreach_maximum + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::maximum +- name: aten::_foreach_minimum + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::minimum - name: aten::_foreach_mul depends: - name: aten::eq @@ -1320,10 +1387,18 @@ - name: aten::to - name: aten::_sample_dirichlet depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::expand - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ - name: aten::sum + - name: aten::to - name: aten::zeros - name: aten::_saturate_weight_to_fp16 depends: @@ -1497,14 +1572,29 @@ - name: aten::is_nonzero - name: aten::_standard_gamma depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::_standard_gamma_grad depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::_std depends: - name: aten::eq @@ -1557,6 +1647,10 @@ - name: aten::is_nonzero - name: aten::mul - name: aten::sub +- name: aten::_test_string_default + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::_thnn_differentiable_gru_cell_backward depends: - name: aten::add @@ -1737,6 +1831,7 @@ depends: - name: aten::abs - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::absolute depends: @@ -1752,6 +1847,7 @@ depends: - name: aten::acos - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -1760,6 +1856,7 @@ - name: aten::is_leaf - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::acos_ depends: @@ -1768,7 +1865,6 @@ - name: aten::is_nonzero - name: aten::acosh depends: - - name: aten::acosh - name: aten::as_strided_ - name: aten::copy_ - name: aten::empty @@ -1874,11 +1970,7 @@ - name: aten::empty_meta - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -1894,39 +1986,11 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to -- name: aten::_add_relu - depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - - name: aten::eq - - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to -- name: aten::_add_relu_ - depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - - name: aten::eq - - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::addbmm depends: - name: aten::addbmm @@ -2082,18 +2146,20 @@ - name: aten::zero_ - name: aten::addr depends: - - name: aten::_addr + - name: aten::add - name: aten::addr + - name: aten::copy_ - name: aten::eq - - name: aten::expand - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - - name: aten::size + - name: aten::mul + - name: aten::outer + - name: aten::resize_ - name: aten::to - name: aten::addr_ depends: - - name: aten::_addr_ + - name: aten::addr - name: aten::eq - name: aten::is_nonzero - name: aten::affine_grid_generator @@ -2362,8 +2428,10 @@ - name: aten::sort - name: aten::as_strided depends: + - name: aten::as_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::permute - name: aten::as_strided_ depends: - name: aten::eq @@ -2389,7 +2457,6 @@ - name: aten::asinh depends: - name: aten::as_strided_ - - name: aten::asinh - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -2408,12 +2475,14 @@ depends: - name: aten::as_strided_ - name: aten::atan + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::atan2 depends: @@ -2450,7 +2519,6 @@ - name: aten::atanh depends: - name: aten::as_strided_ - - name: aten::atanh - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -2529,10 +2597,6 @@ - name: aten::size - name: aten::zero_ - name: aten::zeros_like -- name: aten::backward - depends: - - name: aten::eq - - name: aten::is_nonzero - name: aten::baddbmm depends: - name: aten::addmm_ @@ -2746,8 +2810,16 @@ - name: aten::zero_ - name: aten::binomial depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::bitwise_and depends: @@ -2758,11 +2830,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2799,11 +2867,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2822,11 +2886,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2892,20 +2952,11 @@ - name: aten::is_nonzero - name: aten::bucketize depends: - - name: aten::as_strided_ - name: aten::contiguous - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::can_cast @@ -3075,6 +3126,14 @@ - name: aten::is_nonzero - name: aten::resize_as_ - name: aten::size +- name: aten::choose_qparams_optimized + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::select - name: aten::chunk depends: - name: aten::chunk @@ -3164,6 +3223,7 @@ - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized - name: aten::as_strided_ + - name: aten::clone - name: aten::copy_ - name: aten::copy_sparse_to_sparse_ - name: aten::empty @@ -3172,6 +3232,7 @@ - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero + - name: aten::permute - name: aten::q_per_channel_axis - name: aten::q_per_channel_scales - name: aten::q_per_channel_zero_points @@ -3208,6 +3269,12 @@ - name: aten::select - name: aten::size - name: aten::zero_ +- name: aten::column_stack + depends: + - name: aten::eq + - name: aten::hstack + - name: aten::is_nonzero + - name: aten::reshape - name: aten::combinations depends: - name: aten::arange @@ -3263,10 +3330,12 @@ - name: aten::size - name: aten::contiguous depends: + - name: aten::contiguous - name: aten::copy_ - name: aten::empty_like - name: aten::eq - name: aten::is_nonzero + - name: aten::permute - name: aten::conv1d depends: - name: aten::conv1d @@ -3386,21 +3455,40 @@ - name: aten::size - name: aten::stride - name: aten::to -- name: aten::copy_imag +- name: aten::copy_sparse_to_sparse_ depends: - name: aten::eq - name: aten::is_nonzero -- name: aten::copy_real +- name: aten::copysign depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero -- name: aten::copy_sparse_to_sparse_ + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::copysign_ depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to - name: aten::cos depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::cos - name: aten::empty - name: aten::empty_like @@ -3408,6 +3496,7 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::cos_ depends: @@ -3648,22 +3737,13 @@ - name: aten::is_nonzero - name: aten::deg2rad depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::deg2rad - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::deg2rad_ depends: - name: aten::deg2rad @@ -3684,7 +3764,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fmod_ @@ -3799,11 +3878,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -3819,15 +3894,21 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to +- name: aten::divide + depends: + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero +- name: aten::divide_ + depends: + - name: aten::div_ + - name: aten::eq + - name: aten::is_nonzero - name: aten::dot depends: - name: aten::dot @@ -3897,15 +3978,20 @@ depends: - name: aten::bmm - name: aten::diagonal + - name: aten::dot - name: aten::eq + - name: aten::flatten - name: aten::is_nonzero + - name: aten::movedim - name: aten::mul - name: aten::permute - name: aten::reshape - name: aten::size + - name: aten::squeeze - name: aten::sum - name: aten::unsqueeze - name: aten::view + - name: aten::zeros - name: aten::elu depends: - name: aten::as_strided_ @@ -3940,6 +4026,7 @@ - name: aten::eq - name: aten::index_select - name: aten::is_nonzero + - name: aten::masked_fill_ - name: aten::reshape - name: aten::view - name: aten::embedding_backward @@ -3984,6 +4071,7 @@ - name: aten::ne - name: aten::reshape - name: aten::size + - name: aten::to - name: aten::empty depends: - name: aten::empty @@ -4044,33 +4132,17 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::eq_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::equal depends: - name: aten::as_strided_ @@ -4363,10 +4435,261 @@ - name: aten::unsqueeze - name: aten::fft_fft depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fft2 + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fftfreq + depends: + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::fft_fftfreq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::slice +- name: aten::fft_fftn + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fftshift + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::roll +- name: aten::fft_hfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_real +- name: aten::fft_ifft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifft2 + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifftn + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifftshift + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::roll +- name: aten::fft_ihfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex +- name: aten::fft_irfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_real +- name: aten::fft_irfft2 + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_irfftn + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_rfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex +- name: aten::fft_rfft2 + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd - name: aten::eq - - name: aten::fft - name: aten::is_complex - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_rfftfreq + depends: + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::fft_rfftfreq + - name: aten::is_nonzero + - name: aten::mul_ +- name: aten::fft_rfftn + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze - name: aten::view_as_complex - name: aten::view_as_real - name: aten::fill_ @@ -4461,13 +4784,9 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::floor_divide - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -4481,13 +4800,9 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::floor_divide - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -4566,6 +4881,8 @@ - name: aten::frobenius_norm depends: - name: aten::conj + - name: aten::copy_ + - name: aten::empty - name: aten::eq - name: aten::frobenius_norm - name: aten::is_complex @@ -4575,6 +4892,7 @@ - name: aten::mul - name: aten::norm - name: aten::real + - name: aten::resize_ - name: aten::sqrt - name: aten::sum - name: aten::to @@ -4643,35 +4961,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::ge - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::ge_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::ge - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::gelu depends: - name: aten::_empty_affine_quantized @@ -4759,13 +5061,9 @@ - name: aten::to - name: aten::ger depends: - - name: aten::_addr - - name: aten::empty - name: aten::eq - - name: aten::ger - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::size + - name: aten::outer - name: aten::get_gradients depends: - name: aten::eq @@ -4927,35 +5225,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::gt - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::gt_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::gt - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::hamming_window depends: - name: aten::add_ @@ -5225,6 +5507,23 @@ - name: aten::size - name: aten::squeeze - name: aten::unsqueeze +- name: aten::igamma + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::igamma_ + depends: + - name: aten::eq + - name: aten::igamma + - name: aten::is_nonzero - name: aten::im2col depends: - name: aten::contiguous @@ -5248,6 +5547,7 @@ - name: aten::imag depends: - name: aten::eq + - name: aten::imag - name: aten::is_complex - name: aten::is_nonzero - name: aten::select @@ -5408,9 +5708,11 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq + - name: aten::fill_ - name: aten::is_nonzero - name: aten::resize_ - name: aten::resize_as_ + - name: aten::select - name: aten::to - name: aten::inverse depends: @@ -5575,6 +5877,7 @@ - name: aten::constant_pad_nd - name: aten::div - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::min @@ -5589,6 +5892,8 @@ - name: aten::transpose - name: aten::unsqueeze - name: aten::view + - name: aten::view_as_complex + - name: aten::view_as_real - name: aten::istitle depends: - name: aten::eq @@ -5612,6 +5917,22 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::kaiser_window + depends: + - name: aten::arange + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::kaiser_window + - name: aten::narrow + - name: aten::ones + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::kl_div depends: - name: aten::eq @@ -5631,14 +5952,32 @@ - name: aten::zeros_like - name: aten::kl_div_backward depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::exp - name: aten::expand_as - name: aten::is_nonzero - name: aten::mul - name: aten::neg + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros_like +- name: aten::kron + depends: + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::kron + - name: aten::permute + - name: aten::reshape + - name: aten::resize_ + - name: aten::tensordot - name: aten::kthvalue depends: - name: aten::clone @@ -5712,35 +6051,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::le - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::le_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::le - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::leaky_relu depends: - name: aten::_empty_affine_quantized @@ -5852,10 +6175,10 @@ - name: aten::linalg_norm depends: - name: aten::abs - - name: aten::add_ - name: aten::copy_ - name: aten::empty - name: aten::eq + - name: aten::fill_ - name: aten::flatten - name: aten::frobenius_norm - name: aten::is_nonzero @@ -5864,12 +6187,21 @@ - name: aten::norm - name: aten::nuclear_norm - name: aten::permute - - name: aten::pow - name: aten::resize_ - name: aten::sum - name: aten::svd - name: aten::to - name: aten::unsqueeze_ +- name: aten::linalg_tensorsolve + depends: + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linalg_tensorsolve + - name: aten::movedim + - name: aten::reshape + - name: aten::resize_ + - name: aten::solve - name: aten::linear depends: - name: aten::add_ @@ -5909,6 +6241,7 @@ - name: aten::log depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5918,10 +6251,12 @@ - name: aten::is_nonzero - name: aten::log - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log10 depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5931,6 +6266,7 @@ - name: aten::is_nonzero - name: aten::log10 - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log10_ depends: @@ -5958,6 +6294,7 @@ - name: aten::log2 depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5967,6 +6304,7 @@ - name: aten::is_nonzero - name: aten::log2 - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log2_ depends: @@ -6065,7 +6403,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fill_ @@ -6290,35 +6627,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::lt - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::lt_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::lt - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::lu_solve depends: - name: aten::_lu_solve_helper @@ -6575,6 +6896,7 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::max_pool1d_with_indices + - name: aten::quantized_max_pool1d - name: aten::size - name: aten::squeeze_ - name: aten::max_pool1d_with_indices @@ -6709,15 +7031,18 @@ - name: aten::median depends: - name: aten::clone + - name: aten::contiguous - name: aten::empty - name: aten::eq - - name: aten::fill_ - name: aten::is_nonzero - - name: aten::kthvalue - name: aten::median + - name: aten::resize_ - name: aten::select - name: aten::size - - name: aten::view + - name: aten::squeeze_ + - name: aten::stride + - name: aten::transpose_ + - name: aten::unsqueeze - name: aten::meshgrid depends: - name: aten::eq @@ -6945,11 +7270,8 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -6964,11 +7286,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -7048,6 +7366,16 @@ - name: aten::sum - name: aten::topk - name: aten::uniform_ +- name: aten::multiply + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul +- name: aten::multiply_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ - name: aten::mv depends: - name: aten::_addmv_impl_ @@ -7112,6 +7440,76 @@ - name: aten::lgamma_ - name: aten::sum - name: aten::unsqueeze +- name: aten::nan_to_num + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nan_to_num + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::nan_to_num_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nan_to_num +- name: aten::nanmedian + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nanmedian + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::squeeze_ + - name: aten::stride + - name: aten::transpose_ + - name: aten::unsqueeze +- name: aten::nanquantile + depends: + - name: aten::all + - name: aten::any + - name: aten::broadcast_tensors + - name: aten::ceil_ + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::flatten + - name: aten::gather + - name: aten::ge + - name: aten::is_nonzero + - name: aten::isnan + - name: aten::item + - name: aten::le + - name: aten::lerp_ + - name: aten::logical_and_ + - name: aten::logical_not_ + - name: aten::lt + - name: aten::masked_fill + - name: aten::masked_fill_ + - name: aten::mul + - name: aten::nanquantile + - name: aten::resize_ + - name: aten::scalar_tensor + - name: aten::size + - name: aten::sort + - name: aten::squeeze_ + - name: aten::sub + - name: aten::sum + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unsqueeze + - name: aten::unsqueeze_ + - name: aten::view - name: aten::nansum depends: - name: aten::as_strided @@ -7172,12 +7570,16 @@ depends: - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized + - name: aten::add + - name: aten::addcmul - name: aten::clone - name: aten::dense_dim - name: aten::empty - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::mul + - name: aten::native_batch_norm - name: aten::q_per_channel_axis - name: aten::q_per_channel_scales - name: aten::q_per_channel_zero_points @@ -7186,6 +7588,7 @@ - name: aten::qscheme - name: aten::sparse_dim - name: aten::sparse_resize_and_clear_ + - name: aten::view - name: aten::native_group_norm_backward depends: - name: aten::_empty_affine_quantized @@ -7260,35 +7663,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - name: aten::ne - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::ne_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - name: aten::ne - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::neg depends: - name: aten::as_strided_ @@ -7309,27 +7696,28 @@ - name: aten::neg - name: aten::negative depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::neg - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::negative_ depends: - name: aten::eq - name: aten::is_nonzero - - name: aten::neg + - name: aten::neg_ - name: aten::new_empty depends: - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::new_empty + - name: aten::permute +- name: aten::new_empty_strided + depends: + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::new_empty_strided + - name: aten::permute - name: aten::new_full depends: - name: aten::eq @@ -7339,6 +7727,8 @@ depends: - name: aten::eq - name: aten::is_nonzero + - name: aten::new_zeros + - name: aten::permute - name: aten::zeros - name: aten::nextafter depends: @@ -7494,13 +7884,15 @@ - name: aten::ne_ - name: aten::nuclear_norm depends: + - name: aten::copy_ + - name: aten::empty - name: aten::eq - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - name: aten::nuclear_norm - name: aten::permute - - name: aten::set_ + - name: aten::resize_ - name: aten::sum - name: aten::svd - name: aten::to @@ -7582,8 +7974,10 @@ - name: aten::outer depends: - name: aten::eq - - name: aten::ger - name: aten::is_nonzero + - name: aten::mul + - name: aten::reshape + - name: aten::size - name: aten::output_nr depends: - name: aten::eq @@ -7625,7 +8019,7 @@ - name: aten::set_ - name: aten::pinverse depends: - - name: aten::diag_embed + - name: aten::conj - name: aten::empty - name: aten::eq - name: aten::gt @@ -7635,7 +8029,9 @@ - name: aten::narrow - name: aten::reciprocal - name: aten::svd + - name: aten::to - name: aten::transpose + - name: aten::unsqueeze - name: aten::where - name: aten::zeros - name: aten::pixel_shuffle @@ -7647,8 +8043,16 @@ - name: aten::size - name: aten::poisson depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::poisson_nll_loss depends: @@ -7709,12 +8113,10 @@ - name: aten::empty_strided - name: aten::eq - name: aten::fill_ - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones + - name: aten::item - name: aten::permute - name: aten::pow - name: aten::resize_ @@ -7734,6 +8136,7 @@ - name: aten::eq - name: aten::fill_ - name: aten::is_nonzero + - name: aten::item - name: aten::resize_ - name: aten::resize_as_ - name: aten::result_type @@ -7833,29 +8236,40 @@ - name: aten::quantile depends: - name: aten::all - - name: aten::ceil + - name: aten::any + - name: aten::broadcast_tensors + - name: aten::ceil_ - name: aten::copy_ - name: aten::empty - name: aten::eq - name: aten::flatten - - name: aten::floor + - name: aten::gather - name: aten::ge - - name: aten::index_select - name: aten::is_nonzero + - name: aten::isnan - name: aten::item - name: aten::le - name: aten::lerp_ - name: aten::logical_and_ + - name: aten::logical_not_ + - name: aten::lt + - name: aten::masked_fill + - name: aten::masked_fill_ - name: aten::mul - - name: aten::permute - name: aten::quantile - - name: aten::reshape + - name: aten::resize_ - name: aten::scalar_tensor - name: aten::size - name: aten::sort + - name: aten::squeeze_ - name: aten::sub + - name: aten::sum - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unsqueeze - name: aten::unsqueeze_ + - name: aten::view - name: aten::quantize_per_channel depends: - name: aten::contiguous @@ -7952,6 +8366,13 @@ - name: aten::tanh - name: aten::tanh_ - name: aten::unsafe_chunk +- name: aten::quantized_max_pool1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::quantized_max_pool2d + - name: aten::squeeze + - name: aten::unsqueeze - name: aten::quantized_max_pool2d depends: - name: aten::_empty_affine_quantized @@ -7977,22 +8398,13 @@ - name: aten::tanh - name: aten::rad2deg depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - name: aten::rad2deg - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rad2deg_ depends: - name: aten::eq @@ -8070,11 +8482,17 @@ - name: aten::is_nonzero - name: aten::range - name: aten::resize_ +- name: aten::ravel + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape - name: aten::real depends: - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero + - name: aten::real - name: aten::select - name: aten::view_as_real - name: aten::reciprocal @@ -8097,6 +8515,10 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::reciprocal +- name: aten::record_stream + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::refine_names depends: - name: aten::alias @@ -8181,11 +8603,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -8198,11 +8616,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -8349,10 +8763,6 @@ - name: aten::size - name: aten::zero_ - name: aten::zeros_like -- name: aten::requires_grad - depends: - - name: aten::eq - - name: aten::is_nonzero - name: aten::requires_grad_ depends: - name: aten::eq @@ -8386,22 +8796,10 @@ - name: aten::sparse_dim - name: aten::result_type depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::result_type - name: aten::scalar_tensor - - name: aten::to - name: aten::retain_grad depends: - name: aten::eq @@ -8523,6 +8921,11 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::round +- name: aten::row_stack + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::vstack - name: aten::rowwise_prune depends: - name: aten::contiguous @@ -8549,69 +8952,37 @@ - name: aten::rrelu_with_noise depends: - name: aten::add - - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ - name: aten::div - - name: aten::empty - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rrelu_with_noise_ depends: - name: aten::add - - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rrelu_with_noise_backward depends: - name: aten::add - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu_backward - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::sub - - name: aten::to - name: aten::rsplit depends: - name: aten::eq @@ -8648,11 +9019,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -8715,20 +9082,11 @@ - name: aten::to - name: aten::searchsorted depends: - - name: aten::as_strided_ - name: aten::contiguous - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::select @@ -8782,6 +9140,25 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::sgn + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sgn + - name: aten::to +- name: aten::sgn_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sgn - name: aten::sigmoid depends: - name: aten::_empty_affine_quantized @@ -8828,6 +9205,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::resize_ - name: aten::resize_as_ @@ -8886,12 +9264,14 @@ - name: aten::sin depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::sin - name: aten::to - name: aten::sin_ @@ -8945,7 +9325,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fmod_ @@ -8991,6 +9370,7 @@ - name: aten::bmm - name: aten::contiguous - name: aten::copy_ + - name: aten::detach - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -9109,6 +9489,7 @@ - name: aten::sspaddmm - name: aten::smooth_l1_loss depends: + - name: aten::abs_ - name: aten::as_strided_ - name: aten::copy_ - name: aten::empty @@ -9122,20 +9503,26 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::smooth_l1_loss + - name: aten::sub - name: aten::sum - name: aten::to - name: aten::smooth_l1_loss_backward depends: - name: aten::as_strided_ - name: aten::copy_ + - name: aten::div - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::l1_loss_backward + - name: aten::mul_ - name: aten::resize_ - name: aten::resize_as_ + - name: aten::sign_ - name: aten::smooth_l1_loss_backward + - name: aten::sub - name: aten::to - name: aten::zeros_like - name: aten::soft_margin_loss @@ -9238,28 +9625,25 @@ - name: aten::size - name: aten::sort depends: - - name: aten::_copy_from - name: aten::_make_per_tensor_quantized_tensor + - name: aten::arange + - name: aten::as_strided - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::int_repr - - name: aten::is_complex - name: aten::is_nonzero - name: aten::q_scale - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::set_quantizer_ - name: aten::size - name: aten::sort - name: aten::stride - name: aten::to + - name: aten::zero_ - name: aten::sparse_coo_tensor depends: - name: aten::_sparse_coo_tensor_with_dims @@ -9347,6 +9731,7 @@ - name: aten::size - name: aten::squeeze - name: aten::to + - name: aten::view - name: aten::squeeze_ depends: - name: aten::as_strided_ @@ -9414,19 +9799,24 @@ - name: aten::to - name: aten::stft depends: + - name: aten::_fft_with_size - name: aten::as_strided - name: aten::copy_ - name: aten::eq - name: aten::fill_ + - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - name: aten::narrow - - name: aten::rfft + - name: aten::reshape - name: aten::size + - name: aten::squeeze - name: aten::squeeze_ - name: aten::stride - name: aten::transpose_ - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real - name: aten::zeros - name: aten::stride depends: @@ -9444,11 +9834,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -9464,11 +9850,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -9487,6 +9869,7 @@ depends: - name: aten::as_strided - name: aten::as_strided_ + - name: aten::clone - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -9539,25 +9922,11 @@ - name: aten::transpose_ - name: aten::take depends: - - name: aten::_copy_from - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ + - name: aten::contiguous - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::set_quantizer_ - - name: aten::size - - name: aten::stride - - name: aten::to - name: aten::take_backward depends: - name: aten::eq @@ -9567,6 +9936,7 @@ - name: aten::tan depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -9575,6 +9945,7 @@ - name: aten::is_leaf - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::tan - name: aten::to - name: aten::tan_ @@ -9617,8 +9988,17 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::to +- name: aten::tensor_split + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::size + - name: aten::slice + - name: aten::tensor_split - name: aten::tensordot depends: + - name: aten::copy_ - name: aten::eq - name: aten::is_floating_point - name: aten::is_leaf @@ -9656,6 +10036,7 @@ - name: aten::addmm_ - name: aten::contiguous - name: aten::copy_ + - name: aten::detach - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -9666,7 +10047,6 @@ - name: aten::size - name: aten::unsqueeze - name: aten::view - - name: aten::zero_ - name: aten::thnn_conv_depthwise2d depends: - name: aten::eq @@ -9786,9 +10166,11 @@ - name: aten::zero_ - name: aten::trace depends: + - name: aten::empty - name: aten::eq - name: aten::is_nonzero - - name: aten::scalar_tensor + - name: aten::size + - name: aten::stride - name: aten::trace_backward depends: - name: aten::arange @@ -9924,17 +10306,9 @@ - name: aten::triu_indices - name: aten::true_divide depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::true_divide_ depends: - name: aten::div_ @@ -10350,17 +10724,16 @@ - name: aten::view - name: aten::view_as_complex depends: - - name: aten::empty - name: aten::eq - name: aten::is_nonzero - - name: aten::set_ + - name: aten::permute + - name: aten::view_as_complex - name: aten::view_as_real depends: - - name: aten::empty - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero - - name: aten::set_ + - name: aten::view_as_real - name: aten::vstack depends: - name: aten::atleast_2d @@ -10370,23 +10743,11 @@ - name: aten::where depends: - name: aten::_s_where - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::expand - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - name: aten::nonzero_numpy - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::where - name: aten::zero_ depends: @@ -10464,6 +10825,7 @@ - name: aten::is_nonzero - name: quantized::add depends: + - name: aten::_empty_affine_quantized - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ @@ -10478,6 +10840,7 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::set_quantizer_ + - name: aten::size - name: aten::to - name: quantized::add_out depends: @@ -10970,8 +11333,24 @@ - name: aten::unsqueeze - name: quantized::conv_transpose1d_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros - name: quantized::conv_transpose1d_unpack depends: - name: aten::clone @@ -11000,8 +11379,23 @@ - name: aten::is_nonzero - name: quantized::conv_transpose2d_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: quantized::conv_transpose2d_stride depends: - name: aten::eq @@ -11014,6 +11408,42 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: quantized::conv_transpose3d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_dilation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_groups + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_output_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_stride + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_transpose + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero - name: quantized::conv_unpack depends: - name: aten::eq @@ -11035,10 +11465,13 @@ - name: aten::to - name: quantized::embedding_bag_2bit_prepack depends: + - name: aten::choose_qparams_optimized - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::item + - name: aten::select - name: aten::size - name: quantized::embedding_bag_2bit_unpack depends: @@ -11046,12 +11479,19 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::size +- name: quantized::embedding_bag_4bit + depends: + - name: aten::eq + - name: aten::is_nonzero - name: quantized::embedding_bag_4bit_prepack depends: + - name: aten::choose_qparams_optimized - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::item + - name: aten::select - name: aten::size - name: quantized::embedding_bag_4bit_rowwise_offsets depends: @@ -11060,7 +11500,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::size - - name: aten::to - name: quantized::embedding_bag_4bit_unpack depends: - name: aten::empty @@ -11080,6 +11519,7 @@ - name: aten::size - name: quantized::embedding_bag_byte_rowwise_offsets depends: + - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -11155,6 +11595,22 @@ - name: aten::is_nonzero - name: aten::q_scale - name: aten::q_zero_point +- name: quantized::leaky_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: quantized::linear depends: - name: aten::eq @@ -11169,8 +11625,28 @@ - name: aten::is_nonzero - name: quantized::linear_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: quantized::linear_prepack_fp16 depends: - name: aten::eq @@ -11232,16 +11708,16 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: quantized::max_pool1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::quantized_max_pool1d - name: quantized::max_pool2d depends: - - name: aten::_empty_affine_quantized - - name: aten::contiguous - name: aten::eq - name: aten::is_nonzero - - name: aten::max_pool2d - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::size + - name: aten::quantized_max_pool2d - name: quantized::mul depends: - name: aten::_empty_affine_quantized @@ -11432,6 +11908,23 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::to +- name: quantized::sigmoid + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to - name: quantized::threshold depends: - name: aten::_empty_affine_quantized From eb8331e759463018d341bc21b343ec24a16735b5 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 12 Nov 2020 19:29:32 -0800 Subject: [PATCH 86/93] Revert D24524219: Remove `balance` and `devices` parameter from Pipe. Test Plan: revert-hammer Differential Revision: D24524219 (https://github.com/pytorch/pytorch/commit/8da75763032fb74f16c92fd9a8c98e69cd0b9ce3) Original commit changeset: 9973172c2bb7 fbshipit-source-id: b187c80270adb2a412e3882863a2d7de2a52ed56 --- .../_pipeline/sync/skip/test_gpipe.py | 6 +- .../_pipeline/sync/skip/test_leak.py | 2 +- test/distributed/_pipeline/sync/test_bugs.py | 20 +-- .../_pipeline/sync/test_inplace.py | 2 +- test/distributed/_pipeline/sync/test_pipe.py | 167 +++++++++++------- .../_pipeline/sync/test_transparency.py | 2 +- torch/distributed/_pipeline/sync/pipe.py | 144 ++++++++++----- .../distributed/pipeline/__init__.py | 0 .../_internal/distributed/pipeline/utils.py | 21 --- 9 files changed, 209 insertions(+), 155 deletions(-) delete mode 100644 torch/testing/_internal/distributed/pipeline/__init__.py delete mode 100644 torch/testing/_internal/distributed/pipeline/utils.py diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 96ecd84e0d18..293a263439bc 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -11,7 +11,6 @@ from torch.distributed._pipeline.sync import Pipe from torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange -from torch.testing._internal.distributed.pipeline.utils import convert_to_balance @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @@ -53,8 +52,7 @@ def forward(self, input): return output model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = convert_to_balance(model, balance) - model = Pipe(model, chunks=3, checkpoint=checkpoint) + model = Pipe(model, balance, chunks=3, checkpoint=checkpoint) in_device = model.devices[0] out_device = model.devices[-1] @@ -83,7 +81,7 @@ def forward(self, input): return input model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, chunks=5) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5) input = torch.rand(10, requires_grad=True) output = model(input) diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 31c4ea13b9f1..89e39aa9cedb 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -91,7 +91,7 @@ def forward(self, input): return self.F.apply(input) model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, chunks=2, checkpoint=checkpoint) + model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint) input = torch.rand(10, requires_grad=True) diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index 4f5346a837b5..c3152745b5bb 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -37,7 +37,7 @@ def forward(self, input): return Identity.apply(input) model = nn.Sequential(M(), M()) - model = Pipe(model, checkpoint="always") + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always") x = torch.rand(42) y = model(x) @@ -62,7 +62,7 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, chunks=3) + model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3) with pytest.raises(ExpectedException): model(torch.rand(3)) @@ -86,28 +86,18 @@ def backward(ctx, grad): return grad class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - def forward(self, pair): a, b = pair - a = a * self.ones return a * 1, b * 2, b * 3 class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - def forward(self, triple): a, b, c = triple - a = a * self.ones b = Sleep.apply(b) return a + b + c - model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) - model = Pipe(model, chunks=32, checkpoint="never") + model = nn.Sequential(Layer1(), Layer2()) + model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never") a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) @@ -131,7 +121,7 @@ def forward(self, x): model = nn.Sequential(Dropouts(), Dropouts()) x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, chunks=10, checkpoint="always") + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always") y = model(x) y.norm().backward() diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 17b3dac4eca8..185ad8706054 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -13,7 +13,7 @@ def test_inplace_on_requires_grad(): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, checkpoint="always") + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always") x = torch.rand(1) y = model(x) diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index 9c2964940576..d7915733adc0 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -19,7 +19,7 @@ def test_parameters(): model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=1) + pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1) assert list(pipe.parameters()) != [] @@ -32,8 +32,9 @@ def __str__(self): return self.value model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) + pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always")) + assert pipe.balance == [1] assert pipe.devices == [torch.device("cpu")] assert pipe.chunks == 42 assert isinstance(pipe.chunks, int) @@ -41,12 +42,13 @@ def __str__(self): assert isinstance(pipe.checkpoint, str) -def test_sequential_like(): +@pytest.mark.parametrize("balance", [[2], [1, 1]]) +def test_sequential_like(balance): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model) + model = Pipe(model, balance, devices=["cpu", "cpu"]) assert len(model) == 2 assert list(model) == [a, b] @@ -59,18 +61,54 @@ def test_sequential_like(): assert model[-1] is b assert model[-2] is a + +def test_balance_wrong_length(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + + with pytest.raises(ValueError): + Pipe(model, balance=[1]) + + with pytest.raises(ValueError): + Pipe(model, balance=[3]) + + +def test_balance_less_than_1(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + + with pytest.raises(ValueError): + Pipe(model, balance=[0, 2]) + + with pytest.raises(ValueError): + Pipe(model, balance=[-1, 3]) + + def test_chunks_less_than_1(): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError): - Pipe(model, chunks=0) + Pipe(model, balance=[1], devices=["cpu"], chunks=0) with pytest.raises(ValueError): - Pipe(model, chunks=-1) + Pipe(model, balance=[1], devices=["cpu"], chunks=-1) + + +def test_too_few_devices(): + model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) + + with pytest.raises(IndexError): + # len(balance) > len(devices) + model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"]) + def test_batch_size_indivisible(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) with pytest.warns(None) as record: model(torch.rand(7, 1)) @@ -81,7 +119,7 @@ def test_batch_size_indivisible(): def test_batch_size_small(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) with pytest.warns(None) as record: model(torch.rand(2, 1)) @@ -111,9 +149,9 @@ def count_grad_fn(grad_fn, name, visited=None): model = nn.Sequential(nn.Linear(1, 1)) input = torch.rand(2, 1) - always = Pipe(model, chunks=2, checkpoint="always") - except_last = Pipe(model, chunks=2, checkpoint="except_last") - never = Pipe(model, chunks=2, checkpoint="never") + always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always") + except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last") + never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never") always_output = always(input) except_last_output = except_last(input) @@ -128,21 +166,21 @@ def test_checkpoint_mode_invalid(): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): - Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") + Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT") def test_checkpoint_mode_when_chunks_1(): model = nn.Sequential(nn.Linear(1, 1)) # All checkpoint modes are fine. - Pipe(model, chunks=1, checkpoint="except_last") - Pipe(model, chunks=1, checkpoint="always") - Pipe(model, chunks=1, checkpoint="never") + Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last") + Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always") + Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never") def test_checkpoint_eval(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) input = torch.rand(2, 1) def find_grad_fn(grad_fn, name): @@ -176,7 +214,7 @@ def forward(self, input): return input[0] * 2 model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, chunks=1, checkpoint="always") + model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always") input = torch.rand(1, requires_grad=True) output = model(input) @@ -185,7 +223,7 @@ def forward(self, input): def test_no_grad(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) input = torch.rand(2, 1) latent = None @@ -215,7 +253,7 @@ def forward(self, *_): raise ExpectedException() model = nn.Sequential(Raise()) - model = Pipe(model, chunks=1) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=1) with pytest.raises(ExpectedException): model(torch.rand(1)) @@ -249,7 +287,7 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, chunks=3) + model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3) with pytest.raises(ExpectedException): model(torch.rand(3)) @@ -270,7 +308,7 @@ def forward(self, a_and_b): return (self.fc_a(a), self.fc_b(b)) model = nn.Sequential(Two()) - model = Pipe(model, chunks=2) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) @@ -294,7 +332,7 @@ def forward(self, only_a): return (self.fc(a),) model = nn.Sequential(One()) - model = Pipe(model, chunks=2) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) a = torch.rand(10, 1, requires_grad=True) @@ -308,7 +346,7 @@ def forward(self, only_a): def test_input_varargs(): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model) + model = Pipe(model, balance=[1], devices=["cpu"]) a = torch.rand(1) b = torch.rand(1) @@ -324,7 +362,7 @@ def forward(self, _): return "hello" model = nn.Sequential(NonTensor()) - model = Pipe(model) + model = Pipe(model, balance=[1], devices=["cpu"]) x = torch.rand(1) # TypeError: expected Tensor as element 0 in argument 0, but got str @@ -342,7 +380,7 @@ def forward(self, x): return (x, "hello") model = nn.Sequential(NonTensorTuple()) - model = Pipe(model) + model = Pipe(model, balance=[1], devices=["cpu"]) x = torch.rand(1) # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 @@ -359,7 +397,7 @@ def test_deferred_batch_norm(checkpoint): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( - nn.Sequential(pipe_bn), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True + nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True ) x = torch.rand(4, 3, 10, 10) @@ -375,7 +413,7 @@ def test_deferred_batch_norm_params(checkpoint): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( - nn.Sequential(pipe_bn), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True + nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True ) x = torch.rand(4, 3, 10, 10) @@ -395,8 +433,10 @@ def test_devices(): c = nn.Linear(1, 1) # There are extra two devices. + devices = ["cpu", "cpu", "cpu", "cpu", "cpu"] + model = nn.Sequential(a, b, c) - model = Pipe(model) + model = Pipe(model, [1, 1, 1], devices=devices) cpu = torch.device("cpu") # Extra devices must be discarded. @@ -408,7 +448,7 @@ def test_partitions(): b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) assert isinstance(model.partitions, nn.ModuleList) assert isinstance(model.partitions[0], nn.Sequential) @@ -422,7 +462,7 @@ def test_deny_moving(): b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) # Moving is denied. with pytest.raises(TypeError): @@ -458,7 +498,7 @@ def test_deny_moving(): def test_empty_module(): # Empty sequential module is not illegal. model = nn.Sequential() - model = Pipe(model) + model = Pipe(model, []) assert model(torch.tensor(42)) == torch.tensor(42) assert model((torch.tensor(42),)) == (torch.tensor(42),) @@ -473,7 +513,7 @@ def test_named_children(): b = nn.Linear(1, 1) model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) names = set(n for n, _ in model.named_modules()) assert "partitions.0.a" in names @@ -485,9 +525,23 @@ def test_named_children(): model.a +def test_recommend_auto_balance(): + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # balance is required + Pipe(nn.Sequential()) + + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # module and sum of balance have differen length (module: 0, sum of balance: 1) + Pipe(nn.Sequential(), [1]) + + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # module and sum of balance have different length (module: 2, sum of balance: 1) + Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) + + def test_verify_module_non_sequential(): with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): - Pipe(nn.Module()) + Pipe(nn.Module(), [1]) def test_verify_module_duplicate_children(): @@ -495,45 +549,22 @@ def test_verify_module_duplicate_children(): model = nn.Sequential(conv, conv) with pytest.raises(ValueError, match="module with duplicate children is not supported"): - Pipe(model) + Pipe(model, [1, 1]) @skip_if_no_cuda -def test_verify_module_params_on_same_device(): +def test_verify_module_duplicate_parameters_on_distinct_devices(): class Surrogate(nn.Module): - def __init__(self, param1, param2): + def __init__(self, module): super().__init__() - self.param1 = param1 - self.param2 = param2 - - conv1 = nn.Conv2d(3, 3, 1) - conv2 = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv1, conv2.cuda())) - - with pytest.raises( - ValueError, - match='should have all parameters on a single device, please use .to\(\)' - ' to place the module on a single device'): - Pipe(model) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") -def test_verify_nested_modules(): - model = nn.Sequential( - nn.Sequential( - nn.Linear(32, 16).cuda(0), - nn.Linear(16, 8).cuda(0) - ), - nn.Sequential( - nn.Linear(8, 4).cuda(1), - nn.Linear(4, 2).cuda(1) - ), - ) + self.module = module + + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv), Surrogate(conv)) + + with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): + Pipe(model, [1, 1], devices=["cpu", "cuda"]) - pipe = Pipe(model) - out = pipe(torch.rand(10, 32).cuda(0)) - assert out.device == torch.device("cuda:1") - assert out.size() == torch.Size([10, 2]) def test_verify_module_duplicate_parameters_on_same_device(): class Surrogate(nn.Module): @@ -544,7 +575,7 @@ def __init__(self, module): conv = nn.Conv2d(3, 3, 1) model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - Pipe(model) + Pipe(model, [1, 1], devices=["cpu", "cpu"]) def test_forward_lockstep(): @@ -566,7 +597,7 @@ def forward(self, x): return x model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, chunks=3) + model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3) model(torch.rand(3, 1)) # Expected timeline: (Logs are recorded at !) diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 3d2c77e8fef4..88d9c83b9a07 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -31,7 +31,7 @@ def zero_grad(parameters): zero_grad(model.parameters()) # With Pipe - model = Pipe(model, chunks=4) + model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4) outputs = model(inputs) loss = outputs.mean() diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index 68906958cc0e..500b15b72771 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -65,8 +65,8 @@ def verify_module(module: nn.Sequential) -> None: raise ValueError("module with duplicate children is not supported") -def _verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] +def verify_splitting( + module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device] ) -> None: num_parameters = len(list(module.parameters())) num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) @@ -89,46 +89,66 @@ class BalanceError(ValueError): pass -def _retrieve_device(module: nn.Module) -> torch.device: - """Validates all parameters in the Module have the same device and returns - the appropriate device. - - Arguments: - An ``nn.Module`` to process. +def split_module( + module: nn.Sequential, balance: Iterable[int], devices: List[torch.device], +) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]: + """Splits a module into multiple partitions. Returns: - ``torch.Device`` for the entire module. + A tuple of (partitions, balance, devices). + + Partitions are represented as a :class:`~torch.nn.ModuleList` whose + item is a partition. All layers in a partition are placed in the + same device. Raises: - ValueError: - If devices for ``nn.Module`` parameters are not all same. + BalanceError: + wrong balance + IndexError: + the number of devices is fewer than the number of partitions. + """ + balance = list(balance) - device = None - for parameter in module.parameters(): - if device is None: - device = parameter.device - elif device != parameter.device: - raise ValueError( - 'nn.Module: {}, should have all parameters on a single device,' - ' please use .to() to place the module on a single device'.format(module)) + if len(module) != sum(balance): + raise BalanceError( + "module and sum of balance have different length " + f"(module: {len(module)}, sum of balance: {sum(balance)})" + ) - return device if device is not None else torch.device("cpu") + if any(x <= 0 for x in balance): + raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") -def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: + if len(balance) > len(devices): + raise IndexError( + "too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})" + ) + + j = 0 partitions = [] - devices = [] - for name, module in modules.named_children(): - devices.append(_retrieve_device(module)) - if isinstance(module, nn.Sequential): - partition = module - else: - partition = nn.Sequential(OrderedDict([(name, module)])) - partitions.append(partition) + layers: NamedModules = OrderedDict() + + for name, layer in module.named_children(): + layers[name] = layer + + if len(layers) == balance[j]: + # Group buffered layers as a partition. + partition = nn.Sequential(layers) + + device = devices[j] + partition.to(device) + + partitions.append(partition) + + # Prepare for the next partition. + layers.clear() + j += 1 partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) + del devices[j:] + + return partitions, balance, devices - return partitions, devices MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement") @@ -140,27 +160,28 @@ class Pipe(Module): :: model = nn.Sequential(a, b, c, d) - model = Pipe(model, chunks=8) + model = Pipe(model, balance=[1, 1, 1, 1], chunks=8) output = model(input) - .. _Pipe: https://arxiv.org/abs/2004.09910 + .. _Pipe: https://arxiv.org/abs/1811.06965 Pipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization. - You should place all the modules on the appropriate devices before passing - them to this API and wrap them into an ``nn.Sequential`` module defining the - desired order of execution. + You should determine the balance when defining a :class:`Pipe` module, as + balancing will not be done automatically. The module will be partitioned + into multiple devices according to the given balance. You may rely on + heuristics to find your own optimal configuration. Args: - module (``torch.nn.Sequential``): - Sequential module to be parallelized using pipelining. Each module - in the sequence has to have all of its parameters on a single - device. Each module in the sequence has to either be an nn.Module - or ``nn.Sequential`` (to combine multiple sequential modules on a single - device) + module (torch.nn.Sequential): + sequential module to be parallelized + balance (ints): + list of number of layers in each partition Keyword Args: + devices (iterable of devices): + devices to use (default: all CUDA devices) chunks (int): number of micro-batches (default: ``1``) checkpoint (str): @@ -175,12 +196,33 @@ class Pipe(Module): TypeError: the module is not a :class:`nn.Sequential `. ValueError: - invalid arguments + invalid arguments, or wrong balance IndexError: the number of devices is fewer than the number of partitions. """ + #: The number of layers in each partition. + balance: List[int] = [] + # ^^ + # The default value [] required for Sphinx's autoattribute. + + #: The devices mapped to each partition. + #: + #: ``devices[-1]`` refers to the device of the last partition, which means + #: it is the output device. Probably, you need to use it to transfer the + #: target to calculate the loss without a device mismatch + #: :exc:`RuntimeError`. For example:: + #: + #: out_device = pipe.devices[-1] + #: + #: for input, target in loader: + #: target = target.to(out_device, non_blocking=True) + #: output = pipe(input) + #: loss = F.cross_entropy(output, target) + #: + devices: List[torch.device] = [] + #: The number of micro-batches. chunks: int = 1 @@ -191,6 +233,9 @@ class Pipe(Module): def __init__( self, module: nn.Sequential, + balance: Optional[Iterable[int]] = None, + *, + devices: Optional[Devices] = None, chunks: int = chunks, checkpoint: str = checkpoint, deferred_batch_norm: bool = False, @@ -200,6 +245,8 @@ def __init__( chunks = int(chunks) checkpoint = str(checkpoint) + if balance is None: + raise ValueError(recommend_auto_balance("balance is required")) if chunks <= 0: raise ValueError("number of chunks must be positive integer") if checkpoint not in ["always", "except_last", "never"]: @@ -217,8 +264,17 @@ def __init__( if deferred_batch_norm: module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - self.partitions, self.devices = _split_module(module) - _verify_splitting(module, self.partitions, self.devices) + if devices is None: + devices = range(torch.cuda.device_count()) + devices = [torch.device(d) for d in devices] + devices = cast(List[torch.device], devices) + + try: + self.partitions, self.balance, self.devices = split_module(module, balance, devices) + except BalanceError as exc: + raise ValueError(recommend_auto_balance(str(exc))) + + verify_splitting(module, self.partitions, self.balance, self.devices) self._copy_streams: List[List[AbstractStream]] = [] self._skip_layout = inspect_skip_layout(self.partitions) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/testing/_internal/distributed/pipeline/utils.py b/torch/testing/_internal/distributed/pipeline/utils.py deleted file mode 100644 index 2bf4829b8223..000000000000 --- a/torch/testing/_internal/distributed/pipeline/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from torch import nn -from typing import List - -def convert_to_balance(pipe: nn.Sequential, balance: List[int]): - device_idx = 0 - pipe_idx = 0 - balanced_pipe = [] - for num_layers in balance: - layers = [] - for i in range(num_layers): - layers.append(pipe[pipe_idx]) - pipe_idx += 1 - balanced_pipe.append(nn.Sequential(*layers).to(device_idx)) - device_idx += 1 - - return nn.Sequential(*balanced_pipe) From eab809377df040308525946fdeb40866b21a52f0 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Thu, 12 Nov 2020 20:15:47 -0800 Subject: [PATCH 87/93] [NNC] Remove all deferred expansion from Reductions (#47709) Summary: Refactors the ReduceOp node to remove the last remaining deferred functionality: completing the interaction between the accumulator buffer and the body. This fixes two issues with reductions: 1. Nodes inside the interaction could not be visited or modified, meaning we could generate bad code when the interaction was complex. 2. The accumulator load was created at expansion time and so could not be modified in some ways (ie. vectorization couldn't act on these loads). This simplifies reduction logic quite a bit, but theres a bit more involved in the rfactor transform. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47709 Reviewed By: ZolotukhinM Differential Revision: D24904220 Pulled By: nickgg fbshipit-source-id: 159e5fd967d2d1f8697cfa96ce1bb5fc44920a40 --- test/cpp/tensorexpr/test_memdependency.cpp | 12 +- tools/build_variables.bzl | 1 + torch/csrc/jit/tensorexpr/ir_mutator.cpp | 8 +- torch/csrc/jit/tensorexpr/ir_printer.cpp | 3 +- torch/csrc/jit/tensorexpr/ir_visitor.cpp | 2 +- torch/csrc/jit/tensorexpr/loopnest.cpp | 102 +++++++++--- .../jit/tensorexpr/mem_dependency_checker.cpp | 45 ------ .../jit/tensorexpr/mem_dependency_checker.h | 1 - torch/csrc/jit/tensorexpr/reduction.cpp | 37 +++++ torch/csrc/jit/tensorexpr/reduction.h | 149 +++++++++--------- torch/csrc/jit/tensorexpr/var_substitutor.h | 8 +- 11 files changed, 209 insertions(+), 159 deletions(-) create mode 100644 torch/csrc/jit/tensorexpr/reduction.cpp diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 7b866747ca05..9dae18476bb1 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -575,11 +575,15 @@ void testMemDependencyCheckerLoopReduce() { // The loop contents depend on the initializer too. ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + // Find loads within the reduction: + auto reduceLoads = NodeFinder::find(reduce.node()); // Pull out the access for the load inside the loop. - auto loopLoad = analyzer.accessFor(reduce.node()); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + for (auto* load : reduceLoads) { + auto loopLoad = analyzer.accessFor(load); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + } } // Lowering a reduction doesn't affect dependency analysis. diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index e2d4449f0a64..60e0562e20e5 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -242,6 +242,7 @@ core_sources_full = [ "torch/csrc/jit/tensorexpr/block_codegen.cpp", "torch/csrc/jit/tensorexpr/loopnest.cpp", "torch/csrc/jit/tensorexpr/mem_arena.cpp", + "torch/csrc/jit/tensorexpr/reduction.cpp", "torch/csrc/jit/tensorexpr/registerizer.cpp", "torch/csrc/jit/tensorexpr/tensor.cpp", "torch/csrc/jit/tensorexpr/types.cpp", diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index d3a0cc45d27f..5f0889842b1e 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -270,7 +270,7 @@ const Expr* IRMutator::mutate(const MinTerm* v) { const Expr* IRMutator::mutate(const ReduceOp* v) { const Expr* buf_new_expr = v->accumulator()->accept_mutator(this); const Buf* buf_new = dynamic_cast(buf_new_expr); - auto body = v->body().node()->accept_mutator(this); + const Expr* body_new = v->body()->accept_mutator(this); std::vector new_output_args; std::vector new_reduce_args; @@ -282,11 +282,7 @@ const Expr* IRMutator::mutate(const ReduceOp* v) { } return new ReduceOp( - buf_new, - ExprHandle(body), - v->interaction(), - new_output_args, - new_reduce_args); + buf_new, body_new, new_output_args, new_reduce_args, v->reducer()); } const Expr* IRMutator::mutate(const BaseCallNode* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index ef8135c6887c..848bd70cf5c7 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -349,8 +349,7 @@ void IRPrinter::visit(const MinTerm* v) { void IRPrinter::visit(const ReduceOp* v) { os() << "ReduceOp("; - os() << *v->accumulator() << ", "; - os() << v->complete() << ", "; + os() << *v->body() << ", "; bool first = true; os() << "out_args={"; diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index 6d2b2140d5b3..ae97a6200d8b 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -227,7 +227,7 @@ void IRVisitor::visit(const MinTerm* v) { void IRVisitor::visit(const ReduceOp* v) { v->accumulator()->accept(this); - v->body().node()->accept(this); + v->body()->accept(this); for (auto* e : v->output_args()) { e->accept(this); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 10ac09e61ac9..ada44d71b8cf 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1612,7 +1612,7 @@ class CacheReplacer : public IRMutator { return IRMutator::mutate(v); } - const Expr* newBody = v->body().node()->accept_mutator(this); + const Expr* newBody = v->body()->accept_mutator(this); // Map indices to call-parameters. std::vector newIndices; @@ -1625,11 +1625,7 @@ class CacheReplacer : public IRMutator { } return new ReduceOp( - cache_, - ExprHandle(newBody), - v->interaction(), - newIndices, - v->reduce_args()); + cache_, newBody, newIndices, v->reduce_args(), v->reducer()); } const Buf* buf_; @@ -1739,10 +1735,9 @@ LoopNest::AccessResult LoopNest::cacheAccesses( Stmt* tmp_store = new Store( producer, tmp_params, - new ReduceOp( + reduceOp->reducer()( producer, ExprHandle(new Load(tmp_buf, new_loop_vars_expr, new IntImm(1))), - reduceOp->interaction(), tmp_params, {}), new IntImm(1)); @@ -2030,6 +2025,70 @@ class StoreFinder : public IRVisitor { const Store* store_; }; +class BufReplacer : public IRMutator { + public: + BufReplacer( + const Buf* old_buf, + const std::vector& old_indices, + const Buf* new_buf, + const std::vector& new_indices) + : old_buf_(old_buf), + old_indices_(old_indices), + new_buf_(new_buf), + new_indices_(new_indices) {} + + const Expr* mutate(const Load* v) override { + if (v->buf() != old_buf_) { + return IRMutator::mutate(v); + } + + TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + + bool equal_indices = true; + for (size_t i = 0; i < v->indices().size(); ++i) { + if (!exprEquals(v->indices()[i], old_indices_[i])) { + equal_indices = false; + break; + } + } + if (!equal_indices) { + return IRMutator::mutate(v); + } + + const Expr* mask_new = v->mask()->accept_mutator(this); + return new Load(new_buf_, new_indices_, mask_new); + } + + Stmt* mutate(const Store* v) override { + if (v->buf() != old_buf_) { + return IRMutator::mutate(v); + } + + TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + + bool equal_indices = true; + for (size_t i = 0; i < v->indices().size(); ++i) { + if (!exprEquals(v->indices()[i], old_indices_[i])) { + equal_indices = false; + break; + } + } + if (!equal_indices) { + return IRMutator::mutate(v); + } + + const Expr* new_value = v->value()->accept_mutator(this); + const Expr* mask_new = v->mask()->accept_mutator(this); + return new Store(new_buf_, new_indices_, new_value, mask_new); + } + + private: + const Buf* old_buf_; + const std::vector& old_indices_; + const Buf* new_buf_; + const std::vector& new_indices_; +}; + void LoopNest::rfactor( const Expr* r, const Var* reduction_var, @@ -2102,7 +2161,7 @@ void LoopNest::rfactor( std::vector new_dims = {}; Buf* tmp_buf = - new Buf(new Var("tmp_buf", kHandle), new_dims, reduce_op->body().dtype()); + new Buf(new Var("tmp_buf", kHandle), new_dims, reduce_op->dtype()); auto old_acc = reduce_op->accumulator(); auto new_inner = reduce_op->reduce_args(); @@ -2130,26 +2189,19 @@ void LoopNest::rfactor( } new_outer.emplace_back(reduction_var); + BufReplacer bufReplacer( + reduce_op->accumulator(), reduce_op->output_args(), tmp_buf, new_outer); + const Expr* new_body = reduce_op->body()->accept_mutator(&bufReplacer); + auto first_reduce = new ReduceOp( - tmp_buf, - reduce_op->body(), - reduce_op->interaction(), - new_outer, - new_inner); + tmp_buf, new_body, new_outer, new_inner, reduce_op->reducer()); auto second_reduce_load_indices = reduce_op->output_args(); second_reduce_load_indices.emplace_back(reduction_var); - auto second_reduce_load = ExprHandle(new Load( - reduce_op->body().dtype(), - tmp_buf, - second_reduce_load_indices, - new IntImm(1))); - auto second_reduce = new ReduceOp( - old_acc, - second_reduce_load, - reduce_op->interaction(), - reduce_op->output_args(), - {reduction_var}); + auto second_reduce_load = new Load( + reduce_op->dtype(), tmp_buf, second_reduce_load_indices, new IntImm(1)); + auto second_reduce = reduce_op->reducer()( + old_acc, second_reduce_load, reduce_op->output_args(), {reduction_var}); // 1) replace target for loop (which is a reduction loop) // with an iterative for loop by removing the reduction var from the diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp index 49faa865c612..2b9df8ac9b56 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp @@ -609,51 +609,6 @@ void MemDependencyChecker::visit(const FunctionCall* v) { currentScope_->accesses_.push_back(call); } -void MemDependencyChecker::visit(const ReduceOp* v) { - auto indicesScope = - std::make_shared(currentScope_->block, currentScope_); - currentScope_ = indicesScope; - - for (const Expr* ind : v->output_args()) { - ind->accept(this); - } - - const Var* var = v->accumulator()->base_handle(); - - // ReduceOps are functionally Loads, and the distinction isn't meaningful so - // just record them as Loads. They get lowered directly to load during - // prepareForCodegen anyway. - auto load = std::make_shared( - nextAccess_++, - AccessType::Load, - v, - lastStmt_, - var, - getIndicesBounds(v->output_args())); - - // If there were loads in the output_args, this call depends on them, also - // merge. - if (!indicesScope->accesses_.empty()) { - for (auto& access : indicesScope->accesses_) { - load->addDependency(access); - access->addDependent(load); - } - mergeScope(indicesScope, indicesScope->parent, false); - } - - stmtToAccess_.emplace(lastStmt_, load); - exprToAccess_.emplace(v, load); - - // Intentionally using operator[], we want it to be created if it does not - // exist. - auto& writeHistory = currentScope_->openWrites_[var]; - updateWriteHistory(writeHistory, load, load->id()); - currentScope_->accesses_.push_back(load); - - // accept the body of the reduction to handle further reads. - v->body().node()->accept(this); -} - // This check determines if two accesses within a loop are "safe" from loop-self // dependence. This function does not consider overlap in bound range, but // rather the stride of the bound relative to the loop variable. This is the diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index a1bb91fa17ad..24a0bfd9153c 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -248,7 +248,6 @@ class TORCH_API MemDependencyChecker : public IRVisitor { void visit(const Store* v) override; void visit(const Load* v) override; void visit(const FunctionCall* v) override; - void visit(const ReduceOp* v) override; void visit(const For* v) override; void visit(const Cond* v) override; void visit(const IfThenElse* v) override; diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp new file mode 100644 index 000000000000..a3daeaa808a3 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/reduction.cpp @@ -0,0 +1,37 @@ + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +ReduceOp* Reducer::operator()( + const Buf* result_buf, + ExprHandle body, + const std::vector& output, + const std::vector& inner) const { + return new ReduceOp( + result_buf, + complete(result_buf, interaction_, body, output, inner), + output, + inner, + *this); +} + +ReduceOp* Reducer::operator()( + const Buf* result_buf, + const Expr* body, + const std::vector& output, + const std::vector& inner) const { + return new ReduceOp( + result_buf, + complete(result_buf, interaction_, ExprHandle(body), output, inner), + output, + inner, + *this); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h index b1d335302cb7..40fd58b0cd18 100644 --- a/torch/csrc/jit/tensorexpr/reduction.h +++ b/torch/csrc/jit/tensorexpr/reduction.h @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -17,76 +16,12 @@ namespace tensorexpr { using ParameterList = const std::vector; using ReduceInteraction = std::function; -// An expression representing a Reduction operation (e.g. Sum, Max) broken into -// it's component parts: initialization, accumulation var, acquisition of value -// to be reduced and interaction. -// -// This is intended to be expanded in the loopnest and not make it to codegen. -class ReduceOp : public ExprNode { - public: - ReduceOp( - const Buf* accum, - ExprHandle body, - ReduceInteraction c, - const std::vector& output_args, - const std::vector& reduce_args) - : ExprNodeBase(body.dtype()), - accumulator_(accum), - body_(body), - interaction_(c), - output_args_(output_args), - reduce_args_(reduce_args) {} - - // return the accumulation load expression. - const Buf* accumulator() const { - return accumulator_; - } - - // return the body expression which obtains the value to be reduced. - ExprHandle body() const { - return body_; - } - - // returns a function encoding the interaction between accumulator and the - // reduction value. - ReduceInteraction interaction() const { - return interaction_; - } - - // returns variables associated with the output Tensor. - const std::vector& output_args() const { - return output_args_; - } - - // returns variables associated with the axes of reduction. - const std::vector& reduce_args() const { - return reduce_args_; - } - - // Completes the reduction operator by applying the interaction function to - // the accumulation and the body expression. - ExprHandle complete() const { - std::vector indices(output_args_.begin(), output_args_.end()); - ExprHandle accum = ExprHandle( - new Load(body_.dtype(), accumulator_, indices, new IntImm(1))); - auto e = interaction_(accum, body_); - return e; - } - - private: - const Buf* accumulator_; - ExprHandle body_; - ReduceInteraction interaction_; - std::vector output_args_; - std::vector reduce_args_; -}; - // A Reducer is a user interface describing a particular reduction // operation. It has three components: An initialization value, a way of // interacting each value with the accumulation, and a method for obtaining the // current value to be reduced. It is materialized into a ReduceOp when loop // variables are known. -class Reducer { +class TORCH_API Reducer { public: Reducer(ExprHandle init, ReduceInteraction& interaction) : init_(init.node()), interaction_(interaction) {} @@ -98,6 +33,7 @@ class Reducer { Reducer(ExprHandle init, RI interaction) : init_(init.node()) { interaction_ = interaction; } + virtual ~Reducer() {} const Expr* initializer() const { return init_; @@ -106,10 +42,14 @@ class Reducer { ReduceOp* operator()( const Buf* result_buf, ExprHandle body, - std::vector output, - std::vector inner) const { - return new ReduceOp(result_buf, body, interaction_, output, inner); - } + const std::vector& output, + const std::vector& inner) const; + + ReduceOp* operator()( + const Buf* result_buf, + const Expr* body, + const std::vector& output, + const std::vector& inner) const; // Polymorphic handling of Body functions with a variety of parameters. static ExprHandle getReduceBody( @@ -161,11 +101,78 @@ class Reducer { return func(vars[0], vars[1], vars[2], vars[3]); } + // Completes the reduction operator by applying the interaction function to + // the accumulation and the body expression. + static Expr* complete( + const Buf* accumulator, + ReduceInteraction interaction, + ExprHandle body, + const std::vector& output_args, + const std::vector& reduce_args) { + ExprHandle accum = ExprHandle( + new Load(body.dtype(), accumulator, output_args, new IntImm(1))); + auto e = interaction(accum, body); + return e.node(); + } + private: const Expr* init_; ReduceInteraction interaction_; }; +// An expression representing a Reduction operation (e.g. Sum, Max) broken into +// it's component parts: initialization, accumulation var, acquisition of value +// to be reduced and interaction. +// +// This is intended to be expanded in the loopnest and not make it to codegen. +class ReduceOp : public ExprNode { + public: + ReduceOp( + const Buf* accum, + const Expr* body, + const std::vector& output_args, + const std::vector& reduce_args, + const Reducer& reducer) + : ExprNodeBase(body->dtype()), + accumulator_(accum), + body_(body), + output_args_(output_args), + reduce_args_(reduce_args), + reducer_(reducer) {} + + // return the accumulation load expression. + const Buf* accumulator() const { + return accumulator_; + } + + // return the body expression which obtains the value to be reduced. + const Expr* body() const { + return body_; + } + + // Returns the original Reducer factory that can create ReduceOps. + const Reducer& reducer() const { + return reducer_; + } + + // returns variables associated with the output Tensor. + const std::vector& output_args() const { + return output_args_; + } + + // returns variables associated with the axes of reduction. + const std::vector& reduce_args() const { + return reduce_args_; + } + + private: + const Buf* accumulator_; + const Expr* body_; + std::vector output_args_; + std::vector reduce_args_; + const Reducer reducer_; +}; + class Sum : public Reducer { public: Sum() @@ -232,7 +239,7 @@ class ReductionExpander : public IRMutator { } const Expr* mutate(const ReduceOp* v) override { - return v->complete().node(); + return v->body(); } }; diff --git a/torch/csrc/jit/tensorexpr/var_substitutor.h b/torch/csrc/jit/tensorexpr/var_substitutor.h index 3a02507c6dca..29e0f8de2a01 100644 --- a/torch/csrc/jit/tensorexpr/var_substitutor.h +++ b/torch/csrc/jit/tensorexpr/var_substitutor.h @@ -37,7 +37,7 @@ class VarSubMutator : public IRMutator { } const Expr* mutate(const ReduceOp* var) override { - auto body = var->body().node()->accept_mutator(this); + auto body = var->body()->accept_mutator(this); std::vector new_outer; std::vector new_inner; @@ -59,10 +59,10 @@ class VarSubMutator : public IRMutator { return new ReduceOp( const_cast(var->accumulator()), - ExprHandle(body), - var->interaction(), + body, new_outer, - new_inner); + new_inner, + var->reducer()); } private: From 85c43c3da15ab9671791e6b7a7eabaccf2eeb459 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 12 Nov 2020 20:21:31 -0800 Subject: [PATCH 88/93] [ONNX] Convert _len based on the first dimension length (#47538) Summary: This PR is a bug fix. As UT shows, for multiple-dimensional tensors, the current conversion for _len returns the total number of the tensors. But it should return the first dimension length, as pytorch _len defines. Need `Squeeze` op at the end to ensure it outputs a scalar value. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47538 Reviewed By: malfet Differential Revision: D24870717 Pulled By: bzinodev fbshipit-source-id: c53c745baa6d2fb7cc1de55a19bd2eedb2ad5272 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 14 ++++++++++++++ torch/onnx/symbolic_opset11.py | 3 ++- torch/onnx/symbolic_opset9.py | 3 ++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 609c2e75f330..45c01dac9a4e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -3344,6 +3344,20 @@ def forward(self, x): x = torch.randn(5, 3, 3) self.run_test(model, x) + @skipIfUnsupportedMinOpsetVersion(11) + def test_loop_multi_dim(self): + class LoopMultiDimModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x, y): + for x_ in torch.flip(x.narrow(0, 0, 7), [0]): + y = x_[0][y] + return y + + model = LoopMultiDimModel() + x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long) + y = torch.ones(1, dtype=torch.long) + self.run_test(model, (x, y)) + @skipIfUnsupportedMinOpsetVersion(11) def test_list(self): class ListModel(torch.jit.ScriptModule): diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 8c3f740d8b70..5e431664e93d 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -373,7 +373,8 @@ def masked_scatter(g, self, mask, source): def _len(g, self): if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence": return g.op("SequenceLength", self) - return g.op("Size", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return g.op('Squeeze', sz_0, axes_i=[0]) def __getitem_(g, self, i): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index d06203cb4508..068ddef9d94f 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -658,7 +658,8 @@ def floor(g, input): def _len(g, self): - return g.op("Size", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return g.op('Squeeze', sz_0, axes_i=[0]) @parse_args('v', 't', 't') From 9fa681c5e074b4f5a477430e72918c5c4a147fdc Mon Sep 17 00:00:00 2001 From: Ksenija Stanojevic Date: Thu, 12 Nov 2020 20:32:44 -0800 Subject: [PATCH 89/93] [ONNX] Add export of prim::dtype, prim::tolist (#46019) Summary: Add export of prim::dtype, prim::tolist. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46019 Reviewed By: malfet Differential Revision: D24870870 Pulled By: bzinodev fbshipit-source-id: 7f59e2c8f5ac2dbf83c889c73bd61f96587a296e --- test/onnx/test_pytorch_onnx_onnxruntime.py | 36 ++++++++++++++++++++++ torch/onnx/symbolic_opset9.py | 16 ++++++++++ 2 files changed, 52 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 45c01dac9a4e..1a8965122d8e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -3441,6 +3441,20 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Zero_(), x) + @skipIfONNXShapeInference(True) + @skipIfUnsupportedMinOpsetVersion(9) + def test_tolist(self): + class List(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + cur_shape = torch._shape_as_tensor(input) + final_shape: List[int] = cur_shape.tolist() + pad_tensor = torch.zeros([1, 2] + final_shape) + return pad_tensor + + x = torch.randn(2, 3) + self.run_test(List(), (x,)) + @skipIfUnsupportedMinOpsetVersion(9) def test_list_pass(self): class Slice(torch.nn.Module): @@ -4031,6 +4045,28 @@ def forward(self, input): model = MyModule() self.run_test(model, (x,)) + def test_dtype(self): + class MyModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input, other): + return input.to(dtype=other.dtype) + other + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(MyModel(), (x, y)) + + def test_dtype_eq(self): + class MyModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input, other): + if input.dtype == other.dtype: + return input + other + return input + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(MyModel(), (x, y)) + def test_cast_to(self): class MyModule(torch.jit.ScriptModule): @torch.jit.script_method diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 068ddef9d94f..0a1d95aff7e2 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2254,6 +2254,22 @@ def is_floating_point(g, self): return g.op("Constant", value_t=torch.BoolTensor([0])) +def prim_dtype(g, self): + dtype = sym_help._try_get_scalar_type(self) + dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) + return g.op("Constant", value_t=torch.IntTensor([dtype])) + + +# tolist is currently supported only for 1D input tensors. +# dim_val and elem_ty_val represent dimension and type annotations +# that need to match dimension and type of the input tensor. +def prim_tolist(g, input, dim_val, elem_ty_val): + dim = sym_help._maybe_get_const(dim_val, 'i') + if dim > 1: + return _unimplemented("prim_tolist", "dim_val > 1") + return input + + @parse_args('v', 'i') def one_hot(g, self, num_classes): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) From c4ecbcdcb317ffab2b4bba51915bad77b212b0b0 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 12 Nov 2020 20:42:07 -0800 Subject: [PATCH 90/93] [quant][graphmode][fx][refactor] insert_observer_for_special_module (#47783) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47783 Test Plan: python test/test_quantization.py TestQuantizeFx Imported from OSS Reviewed By: vkuzo Differential Revision: D24900304 fbshipit-source-id: 11cc3dd4ea5e272209db9f3c419deadd40db5f42 --- torch/quantization/fx/quantize.py | 52 +++++++++++++++++-------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 8c131de22fd1..7a3721938c33 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -397,6 +397,34 @@ def insert_observer(node, observer): env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) + def insert_observer_for_special_module(quantize_handler): + """ Insert observer for custom module and standalone module + Returns: standalone_module_input_idxs: the indexs for inputs that needs + to be observed by parent module + """ + standalone_module_input_idxs = None + if isinstance(quantize_handler, CustomModuleQuantizeHandler): + custom_module = self.modules[node.target] + custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) + observed_custom_module_class = \ + get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig) + observed_custom_module = \ + observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(self.modules[parent_name], name, observed_custom_module) + elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): + # observe standalone module + standalone_module = self.modules[node.target] + prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx + observed_standalone_module = prepare(standalone_module, {"": qconfig}) + observed_standalone_module.qconfig = qconfig + standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs + observed_standalone_module = mark_observed_standalone_module(observed_standalone_module) + parent_name, name = _parent_name(node.target) + setattr(self.modules[parent_name], name, observed_standalone_module) + self.modules[node.target] = observed_standalone_module + return standalone_module_input_idxs + result_node : Optional[Node] = None for node in model.graph.nodes: if node.op == 'output': @@ -412,30 +440,8 @@ def insert_observer(node, observer): elif root_node is node: env[node.name] = observed_graph.node_copy(node, load_arg) # index for input of custom module that needs to be observed in parent - standalone_module_input_idxs = None if qconfig is not None: - if isinstance(obj, CustomModuleQuantizeHandler): - custom_module = self.modules[node.target] - custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) - observed_custom_module_class = \ - get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig) - observed_custom_module = \ - observed_custom_module_class.from_float(custom_module) - parent_name, name = _parent_name(node.target) - setattr(self.modules[parent_name], name, observed_custom_module) - - elif isinstance(obj, StandaloneModuleQuantizeHandler): - # observe standalone module - standalone_module = self.modules[node.target] - prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx - observed_standalone_module = prepare(standalone_module, {'': qconfig}) - observed_standalone_module.qconfig = qconfig - standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs - observed_standalone_module = mark_observed_standalone_module(observed_standalone_module) - parent_name, name = _parent_name(node.target) - setattr(self.modules[parent_name], name, observed_standalone_module) - self.modules[node.target] = observed_standalone_module - + standalone_module_input_idxs = insert_observer_for_special_module(obj) # don't need to insert observer for output if activation does not # need to be statically quantized From 59e96c55f7d35be8badf35e4f54c78f4bdbaa100 Mon Sep 17 00:00:00 2001 From: Alberto Alfarano Date: Thu, 12 Nov 2020 20:49:33 -0800 Subject: [PATCH 91/93] Support MatMul in c2_pt_converter Summary: Added the MatMul operator for caffe2 Test Plan: buck test //caffe2/torch/fb/model_transform/c2_convert:c2_pt_converter_test Reviewed By: bugra Differential Revision: D24920937 fbshipit-source-id: 7ba09ba0439cb9bd15d6a41fd8ff1a86d8d11437 --- caffe2/python/brew.py | 1 + caffe2/python/helpers/algebra.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/caffe2/python/brew.py b/caffe2/python/brew.py index 3e7b3212a41c..05663fbee336 100644 --- a/caffe2/python/brew.py +++ b/caffe2/python/brew.py @@ -66,6 +66,7 @@ class HelperWrapper(object): 'add_weight_decay': add_weight_decay, 'elementwise_linear': elementwise_linear, 'layer_norm': layer_norm, + 'mat_mul' : mat_mul, 'batch_mat_mul' : batch_mat_mul, 'cond' : cond, 'loop' : loop, diff --git a/caffe2/python/helpers/algebra.py b/caffe2/python/helpers/algebra.py index 0ce264730f45..e991591305ac 100644 --- a/caffe2/python/helpers/algebra.py +++ b/caffe2/python/helpers/algebra.py @@ -23,6 +23,11 @@ def sub(model, blob_in, blob_out, **kwargs): return model.net.Sub(blob_in, blob_out, **kwargs) +def mat_mul(model, blob_in, blob_out, **kwargs): + """Matrix multiplication""" + return model.net.MatMul(blob_in, blob_out, **kwargs) + + def batch_mat_mul(model, blob_in, blob_out, enable_tensor_core=False, **kwargs): if enable_tensor_core: From 1afdcbfbb34c7885b5cf82e0de52d6a123e71bfa Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 12 Nov 2020 21:37:00 -0800 Subject: [PATCH 92/93] [quant][graphmode][fx][refactor] insert_observer_for_output_of_the_node (#47784) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47784 Test Plan: python test/test_quantization.py TestQuantizeFx Imported from OSS Reviewed By: vkuzo Differential Revision: D24900301 fbshipit-source-id: abaeae1b5747e517adeb0d50cec5998a8a3fc24d --- torch/quantization/fx/quantize.py | 125 ++++++++++++++++-------------- 1 file changed, 68 insertions(+), 57 deletions(-) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 7a3721938c33..4a01c0baa8ca 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -425,6 +425,72 @@ def insert_observer_for_special_module(quantize_handler): self.modules[node.target] = observed_standalone_module return standalone_module_input_idxs + def insert_observer_for_output_of_the_node( + node, + quantize_handler, + qconfig, + standalone_module_input_idxs): + """ Insert observer/fake_quantize module for output of the observed module + if needed + """ + # don't need to insert observer for output if activation does not + # need to be statically quantized + if activation_is_statically_quantized(qconfig): + if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training: + # we only insert fake quantize module in qat + activation_post_process_ctr = \ + get_default_output_activation_post_process_map().get(pattern, None) + assert activation_post_process_ctr is not None, \ + "activation_post_process constructor not provided for " + \ + "pattern:" + str(pattern) + insert_observer(node, activation_post_process_ctr()) + elif (isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and + not model.training) or isinstance(quantize_handler, CopyNode): + # inserting observers for output of observed module, or mark the output + # as observed + assert node.op in [ + 'call_module', + 'call_function', + 'call_method'], \ + 'CopyNode of type ' + node.op + ' is not handled' + + def is_observed(input_arg): + if isinstance(input_arg, Node): + return input_arg.name in observed_node_names_set + elif isinstance(input_arg, list): + return all(map(is_observed, input_arg)) + # propagate observed property from input + if is_observed(node.args[0]): + observed_node_names_set.add(node.name) + elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and + quantize_handler.num_node_args == 1): + input_node = matched_nodes[-1] # first node in the sequence + + def input_is_observed(arg): + return isinstance(arg, Node) and arg.name in observed_node_names_set + # This is checking if one of the argument of add/mul + # is an observed node + # If both of the inputs are number, + # we will not consider the output to be observed + if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]): + observed_node_names_set.add(node.name) + elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): + assert node.op == 'call_module' + output_is_observed = self.modules[node.target]._output_is_observed + if output_is_observed: + observed_node_names_set.add(node.name) + elif quantize_handler.all_node_args: + # observer for outputs + new_observer = qconfig.activation() + insert_observer(node, new_observer) + + # insert observer for input of standalone module + if standalone_module_input_idxs is not None: + for idx in standalone_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: + new_observer = qconfig.activation() + insert_observer(node.args[idx], new_observer) + result_node : Optional[Node] = None for node in model.graph.nodes: if node.op == 'output': @@ -442,63 +508,8 @@ def insert_observer_for_special_module(quantize_handler): # index for input of custom module that needs to be observed in parent if qconfig is not None: standalone_module_input_idxs = insert_observer_for_special_module(obj) - - # don't need to insert observer for output if activation does not - # need to be statically quantized - if activation_is_statically_quantized(qconfig): - if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training: - # we only insert fake quantize module in qat - activation_post_process_ctr = \ - get_default_output_activation_post_process_map().get(pattern, None) - assert activation_post_process_ctr is not None, \ - "activation_post_process constructor not provided for " + \ - "pattern:" + str(pattern) - insert_observer(node, activation_post_process_ctr()) - elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and - not model.training) or isinstance(obj, CopyNode): - # inserting observers for output of observed module, or mark the output - # as observed - assert node.op in [ - 'call_module', - 'call_function', - 'call_method'], \ - 'CopyNode of type ' + node.op + ' is not handled' - - def is_observed(input_arg): - if isinstance(input_arg, Node): - return input_arg.name in observed_node_names_set - elif isinstance(input_arg, list): - return all(map(is_observed, input_arg)) - # propagate observed property from input - if is_observed(node.args[0]): - observed_node_names_set.add(node.name) - elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1: - input_node = matched_nodes[-1] # first node in the sequence - - def input_is_observed(arg): - return isinstance(arg, Node) and arg.name in observed_node_names_set - # This is checking if one of the argument of add/mul - # is an observed node - # If both of the inputs are number, - # we will not consider the output to be observed - if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]): - observed_node_names_set.add(node.name) - elif isinstance(obj, StandaloneModuleQuantizeHandler): - assert node.op == 'call_module' - output_is_observed = self.modules[node.target]._output_is_observed - if output_is_observed: - observed_node_names_set.add(node.name) - elif obj.all_node_args: - # observer for outputs - new_observer = qconfig.activation() - insert_observer(node, new_observer) - - # insert observer for input of standalone module - if standalone_module_input_idxs is not None: - for idx in standalone_module_input_idxs: - if node.args[idx].name not in observed_node_names_set: - new_observer = qconfig.activation() - insert_observer(node.args[idx], new_observer) + insert_observer_for_output_of_the_node( + node, obj, qconfig, standalone_module_input_idxs) else: env[node.name] = observed_graph.node_copy(node, load_arg) From a97c7e2ef00350babd1a2cf336db48f677a3bdf4 Mon Sep 17 00:00:00 2001 From: Ilia Cherniavskii Date: Thu, 12 Nov 2020 21:43:55 -0800 Subject: [PATCH 93/93] Profiler benchmark fix (#47713) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47713 Fix the import and also always use internal Timer Test Plan: python benchmarks/profiler_benchmark/profiler_bench.py Reviewed By: dzhulgakov Differential Revision: D24873991 Pulled By: ilia-cher fbshipit-source-id: 1c3950d7d289a4fb5bd7043ba2d842a35c263eaa --- .../profiler_benchmark/profiler_bench.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/benchmarks/profiler_benchmark/profiler_bench.py b/benchmarks/profiler_benchmark/profiler_bench.py index 6b187b03522e..77924c90a245 100644 --- a/benchmarks/profiler_benchmark/profiler_bench.py +++ b/benchmarks/profiler_benchmark/profiler_bench.py @@ -4,7 +4,7 @@ import timeit import torch -from torch.utils._benchmark import Timer +from torch.utils.benchmark import Timer PARALLEL_TASKS_NUM = 4 INTERNAL_ITER = None @@ -37,8 +37,6 @@ def parallel_task(x): parser.add_argument('--profiling_tensor_size', default=1, type=int) parser.add_argument('--workload', default='loop', type=str) parser.add_argument('--internal_iter', default=256, type=int) - parser.add_argument('--n', default=100, type=int) - parser.add_argument('--use_timer', action='store_true') parser.add_argument('--timer_min_run_time', default=100, type=int) args = parser.parse_args() @@ -47,8 +45,8 @@ def parallel_task(x): print("No CUDA available") sys.exit() - print("Payload: {}; {} iterations, N = {}\n".format( - args.workload, args.internal_iter, args.n)) + print("Payload: {}, {} iterations; timer min. runtime = {}\n".format( + args.workload, args.internal_iter, args.timer_min_run_time)) INTERNAL_ITER = args.internal_iter for profiling_enabled in [False, True]: @@ -90,20 +88,9 @@ def payload(): def payload(): return workload(input_x) - if args.use_timer: - t = Timer( - "payload()", - globals={"payload": payload}, - timer=timeit.default_timer, - ).blocked_autorange(min_run_time=args.timer_min_run_time) - print(t) - else: - runtimes = timeit.repeat(payload, repeat=args.n, number=1) - avg_time = statistics.mean(runtimes) * 1000.0 - stddev_time = statistics.stdev(runtimes) * 1000.0 - print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format( - avg_time, stddev_time)) - if args.workload == "loop": - print("\ttime per iteration: {:.3f} ms".format( - avg_time / args.internal_iter)) - print() + t = Timer( + "payload()", + globals={"payload": payload}, + timer=timeit.default_timer, + ).blocked_autorange(min_run_time=args.timer_min_run_time) + print(t)