Remove hacky double registration of to_here op in reg_distributed_ops (pytorch#39602)

Summary:
Pull Request resolved: pytorch#39602

This was added as a part of
pytorch#38590, but we can use default arguments
here. We use fmt::format to bind the argument's default value to the default
RPC timeout at runtime.
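
For context, a minimal standalone sketch of the fmt::format approach described above, assuming a placeholder constant of 60 seconds in place of torch::distributed::rpc::kDefaultRpcTimeoutSeconds:

// Sketch only: splice a runtime constant into the operator schema string,
// so a single registration covers both rref.to_here() and
// rref.to_here(timeout) call sites.
#include <fmt/format.h>

#include <iostream>
#include <string>

int main() {
  // Placeholder for torch::distributed::rpc::kDefaultRpcTimeoutSeconds.
  constexpr double kDefaultRpcTimeoutSeconds = 60.0;

  const std::string schema = fmt::format(
      "aten::to_here(RRef(t) self, float timeout = {}) -> t(*)",
      kDefaultRpcTimeoutSeconds);

  std::cout << schema << std::endl;
  // Prints something like:
  // aten::to_here(RRef(t) self, float timeout = 60) -> t(*)
}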
ghstack-source-id: 105983645

Test Plan: CI

Differential Revision: D21912719

fbshipit-source-id: 7525c1322a95126f529301be142248af48565b82
rohan-varma authored and xwang233 committed Jun 19, 2020
1 parent e4a6a18 commit d87e790
Showing 2 changed files with 6 additions and 18 deletions.
@@ -95,6 +95,7 @@
('prim::min', datetime.date(2020, 6, 30)),
('prim::max', datetime.date(2020, 6, 30)),
('aten::to_here', datetime.date(2020, 6, 30)),
('aten::to_here(RRef(t) self, double timeout*)', datetime.date(2020, 6, 30)),
('aten::local_value', datetime.date(2020, 6, 30)),
('aten::log', datetime.date(2020, 6, 30)),
('aten::__and__', datetime.date(2020, 6, 30)),
23 changes: 5 additions & 18 deletions torch/csrc/jit/runtime/register_distributed_ops.cpp
@@ -10,6 +10,8 @@
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/library.h>

#include <fmt/format.h>

using at::Scalar;
using at::Tensor;
namespace dist_autograd = torch::distributed::autograd;
@@ -26,24 +28,9 @@ static auto workerInfo =

RegisterOperators reg_rpc_ops(
{Operator(
"aten::to_here(RRef(t) self) -> t(*)",
[](Stack& stack) {
auto rref = pop(stack).toRRef();
IValue res;
if (rref->isOwner()) {
res =
c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
->getValue();
} else {
res = c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
->toHere();
}
push(stack, std::move(res));
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::to_here(RRef(t) self, double timeout) -> t(*)",
fmt::format(
"aten::to_here(RRef(t) self, float timeout = {}) -> t(*)",
torch::distributed::rpc::kDefaultRpcTimeoutSeconds),
[](Stack& stack) {
auto timeout = pop(stack).toDouble();
auto rref = pop(stack).toRRef();
