Refactor RPC matchBuiltInOp to get rid of exception swallowing (#49009)
Summary:
Pull Request resolved: #49009

As per the title, we should generally not swallow exceptions. With this commit,
a true error in JIT operator resolution is propagated back to the RPC caller
rather than being silently swallowed along with any other exceptions that may
happen. Swallowing these exceptions previously led to hard-to-debug issues,
such as unexpected ops showing up in the profiler, and flaky tests that were fixed by
#41287

Added a unittest that validates the error that comes from `jit/pybind_utils.h`.
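For illustration only (not part of this commit), a minimal sketch of how the propagated error now surfaces on the caller side, assuming a two-worker RPC setup on localhost with the default backend and a free port 29500:

# Minimal sketch (assumptions: two workers on localhost, default RPC backend).
# With this change, a schema-match failure during builtin op resolution on the
# callee is raised on the caller as a RuntimeError instead of being swallowed.
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        try:
            # torch.add with no arguments cannot match any builtin schema.
            rpc.rpc_sync("worker1", torch.add, args=())
        except RuntimeError as err:
            print("caller observed:", err)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)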
ghstack-source-id: 118794661

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D25392905

fbshipit-source-id: 6f93251635740bcf902824548b2bc6f9249be5f0
rohan-varma authored and facebook-github-bot committed Dec 17, 2020
1 parent b8d98f0 commit a727bf2
Showing 3 changed files with 62 additions and 29 deletions.
53 changes: 30 additions & 23 deletions torch/csrc/distributed/rpc/python_functions.cpp
@@ -61,36 +61,42 @@ std::shared_ptr<Operator> matchBuiltinOp(
     const py::kwargs& kwargs,
     Stack& stack) {
   Symbol symbol = Symbol::fromQualString(opName);
-  std::vector<std::shared_ptr<jit::Operator>> candidates;

+  std::shared_ptr<jit::Operator> matchedOperator;
   if (symbol.is_aten()) {
-    for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
-      try {
-        // FIXME: This is temporary solution. We should at least refactor
-        // ``createStackForSchema`` to avoid throwing an error.
-        stack = torch::jit::createStackForSchema(
-            op->schema(), args, kwargs, c10::nullopt);
-      } catch (std::runtime_error& e) {
-        VLOG(1) << "Couldn't match schema: " << op->schema()
-                << " to args: " << args << " and kwargs: " << kwargs
-                << ", reason: " << e.what();
-        continue;
-      }
-
-      // Prefer C10 ops so that they go through C10 dispatch. We expect the
-      // total # of possible overloaded ops to be small (i.e. it is 10 for
-      // torch.add) so a worst-case linear search should not incur significant
-      // extra overhead.
+    // Prefer C10 ops so that they go through C10 dispatch. We expect the
+    // total # of possible overloaded ops (i.e. size of below ops list) to be
+    // small (i.e. it is 10 for torch.add) so a worst-case linear search should
+    // not incur significant extra overhead.
+    auto ops = torch::jit::getAllOperatorsFor(symbol);
+    std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol;
+    for (auto it = ops.begin(); it != ops.end();) {
+      std::shared_ptr<jit::Operator> op = *it;
       if (op->isC10Op()) {
-        return op;
+        c10OpsForSymbol.emplace_back(std::move(op));
+        it = ops.erase(it);
+      } else {
+        ++it;
       }
-      candidates.emplace_back(op);
     }
+
+    // Don't throw on failures in this call, since we are not examining on all
+    // operators here, and the matched operator may indeed not be a c10 op.
+    std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack>
+        opWithStack;
+    try {
+      opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs);
+    } catch (const std::runtime_error& e) {
+      opWithStack = torch::jit::getOpWithStack(ops, args, kwargs);
+    }
+    matchedOperator = std::get<0>(opWithStack);
+    stack = std::get<1>(opWithStack);
   }

-  // Ensure that we generated some candidates.
+  // We should never hit this path, since if !matchedOperator, then the last
+  // call to getOpWithStack should have thrown.
   TORCH_CHECK(
-      !candidates.empty(),
+      matchedOperator != nullptr,
       "Failed to match operator name ",
       opName,
       " and arguments "
@@ -99,7 +105,8 @@ std::shared_ptr<Operator> matchBuiltinOp(
       ", kwargs: ",
       kwargs,
       ") to a builtin operator");
-  return candidates[0];
+
+  return matchedOperator;
 }

 std::shared_ptr<FutureMessage> sendPythonRemoteCall(
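For readers skimming the diff above, here is a rough Python analogue of the new control flow in matchBuiltinOp (a sketch only; resolve_overload is a hypothetical stand-in for torch::jit::getOpWithStack, and the op objects are assumed to expose an is_c10 flag):

# Rough Python analogue of the refactored matchBuiltinOp (hypothetical names;
# resolve_overload stands in for torch::jit::getOpWithStack and either returns
# a matched (op, stack) pair or raises RuntimeError).
def match_builtin_op(overloads, args, kwargs, resolve_overload):
    c10_ops = [op for op in overloads if op.is_c10]
    other_ops = [op for op in overloads if not op.is_c10]

    # Prefer C10 ops so that they go through C10 dispatch; if none of them
    # matches, retry with the remaining overloads. Only the first call may
    # fail quietly -- a failure of the fallback call propagates to the caller
    # instead of being swallowed per overload.
    try:
        matched_op, stack = resolve_overload(c10_ops, args, kwargs)
    except RuntimeError:
        matched_op, stack = resolve_overload(other_ops, args, kwargs)

    return matched_op, stack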
30 changes: 24 additions & 6 deletions torch/csrc/jit/python/pybind_utils.h
@@ -879,6 +879,15 @@ inline IValue argumentToIValue(
             py::repr(object)),
         "\nCast error details: ",
         error.what()));
+  } catch (const py::error_already_set& error) {
+    throw schema_match_error(c10::str(
+        schema.formatTypeMismatchMsg(
+            argument,
+            friendlyTypeName(object),
+            argumentPosition,
+            py::repr(object)),
+        "\n Python error details: ",
+        error.what()));
   }
 }

@@ -1245,20 +1254,18 @@ inline py::object invokeScriptMethodFromPython(
       });
 }

-inline py::object invokeOperatorFromPython(
+inline std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
     const std::vector<std::shared_ptr<Operator>>& operations,
     py::args args,
     const py::kwargs& kwargs) {
   Stack stack;
-
   if (operations.size() == 1) {
-    const Operator& op = *operations.at(0);
+    std::shared_ptr<Operator> op = operations.at(0);
     // Create a stack full of the arguments and keyword arguments.
     stack = createStackForSchema(
-        op.schema(), std::move(args), kwargs, c10::nullopt);
+        op->schema(), std::move(args), kwargs, c10::nullopt);

-    pybind11::gil_scoped_release no_gil_guard;
-    op.getOperation()(&stack);
+    return std::make_pair(op, stack);
   } else {
     std::vector<schema_match_error> errors;
     std::shared_ptr<Operator> found_op = nullptr;
@@ -1280,6 +1287,17 @@
       throw std::runtime_error(ss.str());
     }

+    return std::make_pair(found_op, stack);
+  }
+}
+inline py::object invokeOperatorFromPython(
+    const std::vector<std::shared_ptr<Operator>>& operations,
+    py::args args,
+    const py::kwargs& kwargs) {
+  auto opWithStack = getOpWithStack(operations, args, kwargs);
+  std::shared_ptr<Operator> found_op = std::get<0>(opWithStack);
+  Stack stack = std::get<1>(opWithStack);
+  {
     pybind11::gil_scoped_release no_gil_guard;
     found_op->getOperation()(&stack);
   }
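The pybind_utils.h change factors overload resolution out of invocation so the RPC layer can reuse it without executing the op. Roughly, in the same Python-analogue form (hypothetical names; the real implementation is the C++ above), get_op_with_stack plays the role of the resolve_overload stand-in from the earlier sketch:

# Rough Python analogue of the getOpWithStack / invokeOperatorFromPython split
# (hypothetical names; make_stack stands in for createStackForSchema and
# MatchError for schema_match_error).
class MatchError(Exception):
    pass


def get_op_with_stack(overloads, args, kwargs, make_stack):
    # Resolution only: pick the overload whose schema accepts the arguments
    # and build its argument stack, but do not execute anything yet.
    errors = []
    for op in overloads:
        try:
            return op, make_stack(op.schema, args, kwargs)
        except MatchError as error:
            errors.append(error)
    # Report every per-overload failure at once instead of returning nothing.
    raise RuntimeError(
        "failed to match any schema:\n" + "\n".join(str(e) for e in errors)
    )


def invoke_operator_from_python(overloads, args, kwargs, make_stack):
    # Invocation is now a thin wrapper: resolve first, then run the operation.
    op, stack = get_op_with_stack(overloads, args, kwargs, make_stack)
    return op.run(stack)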
8 changes: 8 additions & 0 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -4859,3 +4859,11 @@ def test_device_maps_remote(self):
         self.assertEqual(rref.to_here(), torch.ones(2).to(1))

         rpc.shutdown()
+
+    @dist_init
+    def test_op_with_invalid_args(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(
+            RuntimeError, "Overloaded torch operator invoked from Python failed to many any schema"
+        ):
+            rpc.rpc_sync(dst, torch.add, args=())
