diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp
index b7c16639b19b..5e2e8304b7bd 100644
--- a/torch/csrc/distributed/rpc/python_functions.cpp
+++ b/torch/csrc/distributed/rpc/python_functions.cpp
@@ -61,36 +61,42 @@ std::shared_ptr<Operator> matchBuiltinOp(
     const py::kwargs& kwargs,
     Stack& stack) {
   Symbol symbol = Symbol::fromQualString(opName);
-  std::vector<std::shared_ptr<Operator>> candidates;
+  std::shared_ptr<torch::jit::Operator> matchedOperator;
   if (symbol.is_aten()) {
-    for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
-      try {
-        // FIXME: This is temporary solution. We should at least refactor
-        // ``createStackForSchema`` to avoid throwing an error.
-        stack = torch::jit::createStackForSchema(
-            op->schema(), args, kwargs, c10::nullopt);
-      } catch (std::runtime_error& e) {
-        VLOG(1) << "Couldn't match schema: " << op->schema()
-                << " to args: " << args << " and kwargs: " << kwargs
-                << ", reason: " << e.what();
-        continue;
-      }
-
-      // Prefer C10 ops so that they go through C10 dispatch. We expect the
-      // total # of possible overloaded ops to be small (i.e. it is 10 for
-      // torch.add) so a worst-case linear search should not incur significant
-      // extra overhead.
+    // Prefer C10 ops so that they go through C10 dispatch. We expect the
+    // total # of possible overloaded ops (i.e. the size of the ops list
+    // below) to be small (i.e. it is 10 for torch.add), so a worst-case
+    // linear search should not incur significant extra overhead.
+    auto ops = torch::jit::getAllOperatorsFor(symbol);
+    std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol;
+    for (auto it = ops.begin(); it != ops.end();) {
+      std::shared_ptr<torch::jit::Operator> op = *it;
       if (op->isC10Op()) {
-        return op;
+        c10OpsForSymbol.emplace_back(std::move(op));
+        it = ops.erase(it);
+      } else {
+        ++it;
       }
-      candidates.emplace_back(op);
     }
+
+    // Don't throw on failures in this call, since we are not examining all
+    // operators here, and the matched operator may indeed not be a C10 op.
+    std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack>
+        opWithStack;
+    try {
+      opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs);
+    } catch (const std::runtime_error& e) {
+      opWithStack = torch::jit::getOpWithStack(ops, args, kwargs);
+    }
+    matchedOperator = std::get<0>(opWithStack);
+    stack = std::get<1>(opWithStack);
   }
 
-  // Ensure that we generated some candidates.
+  // We should never hit this path, since if !matchedOperator, then the last
+  // call to getOpWithStack should have thrown.
   TORCH_CHECK(
-      !candidates.empty(),
+      matchedOperator != nullptr,
       "Failed to match operator name ",
       opName,
       " and arguments "
@@ -99,7 +105,8 @@ std::shared_ptr<Operator> matchBuiltinOp(
       ", kwargs: ",
       kwargs,
       ") to a builtin operator");
-  return candidates[0];
+
+  return matchedOperator;
 }
 
 std::shared_ptr<FutureMessage> sendPythonRemoteCall(
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 95b47d142122..140e6f590544 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -879,6 +879,15 @@ inline IValue argumentToIValue(
             py::repr(object)),
         "\nCast error details: ",
         error.what()));
+  } catch (const py::error_already_set& error) {
+    throw schema_match_error(c10::str(
+        schema.formatTypeMismatchMsg(
+            argument,
+            friendlyTypeName(object),
+            argumentPosition,
+            py::repr(object)),
+        "\n Python error details: ",
+        error.what()));
   }
 }
 
@@ -1245,20 +1254,18 @@ inline py::object invokeScriptMethodFromPython(
       });
 }
 
-inline py::object invokeOperatorFromPython(
+inline std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
     const std::vector<std::shared_ptr<Operator>>& operations,
     py::args args,
     const py::kwargs& kwargs) {
   Stack stack;
   if (operations.size() == 1) {
-    const Operator& op = *operations.at(0);
+    std::shared_ptr<Operator> op = operations.at(0);
     // Create a stack full of the arguments and keyword arguments.
     stack = createStackForSchema(
-        op.schema(), std::move(args), kwargs, c10::nullopt);
+        op->schema(), std::move(args), kwargs, c10::nullopt);
 
-    pybind11::gil_scoped_release no_gil_guard;
-    op.getOperation()(&stack);
+    return std::make_pair(op, stack);
   } else {
     std::vector<std::string> errors;
     std::shared_ptr<Operator> found_op = nullptr;
@@ -1280,6 +1287,17 @@ inline py::object invokeOperatorFromPython(
       throw std::runtime_error(ss.str());
     }
 
+    return std::make_pair(found_op, stack);
+  }
+}
+inline py::object invokeOperatorFromPython(
+    const std::vector<std::shared_ptr<Operator>>& operations,
+    py::args args,
+    const py::kwargs& kwargs) {
+  auto opWithStack = getOpWithStack(operations, args, kwargs);
+  std::shared_ptr<Operator> found_op = std::get<0>(opWithStack);
+  Stack stack = std::get<1>(opWithStack);
+  {
     pybind11::gil_scoped_release no_gil_guard;
     found_op->getOperation()(&stack);
   }
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index a149c541a090..8eec8100270b 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -4859,3 +4859,11 @@ def test_device_maps_remote(self):
         self.assertEqual(rref.to_here(), torch.ones(2).to(1))
 
         rpc.shutdown()
+
+    @dist_init
+    def test_op_with_invalid_args(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(
+            RuntimeError, "Overloaded torch operator invoked from Python failed to many any schema"
+        ):
+            rpc.rpc_sync(dst, torch.add, args=())
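
Note: a minimal single-process sketch of the caller-visible behavior the new test exercises, not part of the patch. The worker name "worker0", the MASTER_ADDR/MASTER_PORT values, and world_size=1 are illustrative assumptions chosen so the snippet can run standalone; the cross-worker test above does the same thing between two ranks.

    import os

    import torch
    import torch.distributed.rpc as rpc

    # Rendezvous settings for a single local worker (illustrative values).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    rpc.init_rpc("worker0", rank=0, world_size=1)
    try:
        # torch.add has no zero-argument overload, so the caller-side schema
        # match (matchBuiltinOp -> getOpWithStack) fails and the error
        # surfaces as a RuntimeError, which is what the new test asserts.
        rpc.rpc_sync("worker0", torch.add, args=())
    except RuntimeError as err:
        print(err)
    finally:
        rpc.shutdown()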