Move overloaded_args from FunctionSignature to PythonArgs (#106983)
This moves the `overloaded_args` field from FunctionSignature to PythonArgs. FunctionSignature is shared by all calls and should be immutable. PythonArgs contains the parsing results for a single call to the PyTorch API.
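As a rough sketch of the move (declarations heavily simplified and illustrative only; note that the diff also changes the element type from `py::handle` to raw `PyObject*`):

```cpp
#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

// Before (simplified): the signature object, shared by every call to the
// same API function, carried per-call state.
struct FunctionSignatureBefore {
  // ... immutable parameter metadata ...
  std::vector<py::handle> overloaded_args; // overwritten on each parse
};

// After (simplified): the per-call result object owns that state instead,
// and the shared signature stays immutable.
struct PythonArgsAfter {
  // ... parsed argument values for one call ...
  std::vector<PyObject*> overloaded_args; // args that carry __torch_function__
};
```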

I did not measure a performance difference in the "overrides_benchmark", although I expect a bit more work in the common case. Note that the noise in the benchmark is much larger than the differences reported below:

Before:
```
Type tensor had a minimum time of 2.3615360260009766 us and a standard deviation of 0.7833134150132537 us.
Type SubTensor had a minimum time of 10.473251342773438 us and a standard deviation of 0.1973132457351312 us.
Type WithTorchFunction had a minimum time of 5.484819412231445 us and a standard deviation of 0.13305981701705605 us.
Type SubWithTorchFunction had a minimum time of 11.098146438598633 us and a standard deviation of 0.15598918253090233 us.
```
After:
```
Type tensor had a minimum time of 2.2134780883789062 us and a standard deviation of 0.802064489107579 us.
Type SubTensor had a minimum time of 10.625839233398438 us and a standard deviation of 0.15155907021835446 us.
Type WithTorchFunction had a minimum time of 5.520820617675781 us and a standard deviation of 0.23115111980587244 us.
Type SubWithTorchFunction had a minimum time of 11.227846145629883 us and a standard deviation of 0.23032321769278497 us.
```

Fixes #106974

Pull Request resolved: #106983
Approved by: https://github.com/zou3519, https://github.com/ezyang, https://github.com/albanD
colesbury authored and pytorchmergebot committed Aug 16, 2023
1 parent 1f6c1d9 commit d0e50d9
Showing 5 changed files with 60 additions and 58 deletions.
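For orientation before the per-file diffs, here is a self-contained sketch of how the per-call vector is now threaded through `FunctionSignature::parse` and into `PythonArgs` (the identifiers follow the diff below, but the bodies are trivial stand-ins rather than the real PyTorch implementation):

```cpp
#include <Python.h>

#include <stdexcept>
#include <utility>
#include <vector>

struct FunctionSignature {
  // Shared and immutable. parse() now appends any arguments that carry a
  // __torch_function__ override into the caller-provided vector.
  bool parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* dst[],
      std::vector<PyObject*>& overloaded_args,
      bool raise_exception) {
    // Stand-in body: the real matching logic is in the diff below.
    (void)self; (void)args; (void)kwargs; (void)dst;
    (void)overloaded_args; (void)raise_exception;
    return true;
  }
};

struct PythonArgs {
  PythonArgs(
      bool /*traceable*/,
      const FunctionSignature& /*signature*/,
      PyObject** /*args*/,
      std::vector<PyObject*> overloaded_args)
      : overloaded_args(std::move(overloaded_args)) {}

  std::vector<PyObject*> overloaded_args; // read later by handle_torch_function
};

struct PythonArgParser {
  std::vector<FunctionSignature> signatures_;
  bool traceable = false;

  PythonArgs raw_parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]) {
    for (auto& signature : signatures_) {
      // Per-call state: a fresh vector for each parse attempt...
      std::vector<PyObject*> overloaded_args;
      if (signature.parse(
              self, args, kwargs, parsed_args, overloaded_args, false)) {
        // ...moved into the per-call PythonArgs, so callers read
        // r.overloaded_args instead of r.signature.overloaded_args.
        return PythonArgs(
            traceable, signature, parsed_args, std::move(overloaded_args));
      }
    }
    throw std::runtime_error("no matching signature"); // error path simplified
  }
};
```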
torch/csrc/PyInterpreter.cpp (4 changes: 2 additions & 2 deletions)
@@ -153,7 +153,7 @@ py::object torchDispatchFromTensorImpl(
PyGILState_Check(),
"GIL must be held before you call parseIValuesToPyArgsKwargs");

- std::vector<py::handle> overloaded_args;
+ std::vector<PyObject*> overloaded_args;
// TODO: there should be a shorter way to spell this
// TODO: fix the constness of target
at::Tensor self_t = at::Tensor(
@@ -280,7 +280,7 @@ void ConcretePyInterpreterVTable::dispatch(

py::gil_scoped_acquire g;

- std::vector<py::handle> overloaded_args;
+ std::vector<PyObject*> overloaded_args;
py::handle torch_api_function_overload = getTorchApiFunction(op);

// Find overloaded tensors
torch/csrc/jit/python/pybind_utils.cpp (2 changes: 1 addition & 1 deletion)
@@ -775,7 +775,7 @@ py::object _get_operation_for_overload_or_packet(
const py::kwargs& kwargs,
bool is_overload,
c10::optional<c10::DispatchKey> dk) {
- std::vector<py::handle> overloaded_args;
+ std::vector<PyObject*> overloaded_args;
size_t total_arg_num = args.size() + kwargs.size();
for (const auto i : c10::irange(args.size())) {
is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
torch/csrc/jit/python/pybind_utils.h (2 changes: 1 addition & 1 deletion)
@@ -1017,7 +1017,7 @@ inline c10::optional<py::object> maybeTorchFunctionDispatch(
py::tuple args = py::cast(args_vec);

// Handle __torch_function__ dispatch
- std::vector<py::handle> overloaded_args;
+ std::vector<PyObject*> overloaded_args;
size_t total_arg_num = args.size() + kwargs.size();
for (const auto& arg : args) {
is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args);
torch/csrc/utils/python_arg_parser.cpp (74 changes: 33 additions & 41 deletions)
@@ -233,7 +233,7 @@ auto handle_torch_function(
torch_api_function.ptr() != nullptr, "torch API function must exist");
py::tuple args_ = combine_self_args(self, args);
return handle_torch_function_no_python_arg_parser(
- {py::handle(self)},
+ {self},
args_.ptr(),
kwargs,
func_name.c_str(),
@@ -259,7 +259,7 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {

// See Note: [Overloaded args] for what they hold
auto handle_torch_function_no_python_arg_parser(
- at::ArrayRef<py::handle> overloaded_args,
+ at::ArrayRef<PyObject*> overloaded_args,
PyObject* args,
PyObject* kwargs,
const char* func_name,
@@ -283,8 +283,8 @@ auto handle_torch_function_no_python_arg_parser(
std::vector<py::object> overloaded_types;
overloaded_types.reserve(overloaded_args.size());
for (auto& arg : overloaded_args) {
- overloaded_types.push_back(py::reinterpret_borrow<py::object>(
- get_type_of_overloaded_arg(arg.ptr())));
+ overloaded_types.push_back(
+ py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
}
py::tuple py_types = py::cast(overloaded_types);
py::object ret;
@@ -355,14 +355,14 @@ auto handle_torch_function_no_python_arg_parser(
if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
for (auto& arg : overloaded_args) {
py::object torch_function =
- PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str);
+ PyObject_FastGetAttrString(arg, torch_function_name_str);
if (!torch_function) {
TORCH_INTERNAL_ASSERT(0);
}

// See https://github.com/pytorch/pytorch/issues/63767
if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
- .is(arg) &&
+ .is(py::handle(arg)) &&
torch_function.ptr() != torch::disabled_torch_function_impl()) {
TORCH_WARN(
"Defining your `",
@@ -408,8 +408,8 @@ auto handle_torch_function_no_python_arg_parser(
ss << " - mode object " << py::repr(mode_obj) << "\n";
}
for (auto& arg : overloaded_args) {
ss << " - tensor subclass "
<< py::repr(get_type_of_overloaded_arg(arg.ptr())) << "\n";
ss << " - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
<< "\n";
}
ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
const std::string& tmp = ss.str();
@@ -432,18 +432,9 @@ auto handle_torch_function(
(char*)(func_name_override ? func_name_override : r.get_func_name().c_str()));
TORCH_INTERNAL_ASSERT(
torch_api_function.ptr() != nullptr, "torch API function must exist");
- py::object ret;
py::tuple args_ = combine_self_args(self, args);
- // overloaded_args already all have unique types
- std::vector<py::object> overloaded_types;
- overloaded_types.reserve(r.signature.overloaded_args.size());
- for (auto& arg : r.signature.overloaded_args) {
- overloaded_types.push_back(
- py::reinterpret_borrow<py::object>((PyObject*)Py_TYPE(arg.ptr())));
- }
- py::tuple py_types = py::cast(overloaded_types);
return handle_torch_function_no_python_arg_parser(
- r.signature.overloaded_args,
+ r.overloaded_args,
args_.ptr(),
kwargs,
r.get_func_name().c_str(),
@@ -473,7 +464,7 @@ auto handle_torch_function_indexing(
} else {
index_tup = py::make_tuple(py::handle(index));
}
- std::vector<py::handle> overridable_args;
+ std::vector<PyObject*> overridable_args;
is_tensor_and_append_overloaded(self, &overridable_args);
auto size = PyTuple_GET_SIZE(index_tup.ptr());
for (auto i : c10::irange(size)) {
@@ -518,7 +509,7 @@ auto handle_torch_function_indexing(
* entry in overloaded_args for this type with higher precedence than
* the superclass.
*
- * See torch._overrides._get_overloaded_types_and_args for the equivalent
+ * See torch._overrides._get_overloaded_args for the equivalent
* function in the Python __torch_function__ implementation.
*
* The precedence-determining algorithm implemented in this function is
@@ -538,13 +529,13 @@ auto handle_torch_function_indexing(
*/

static void append_overloaded_arg(
- std::vector<py::handle>* overloaded_args,
+ std::vector<PyObject*>* overloaded_args,
PyObject* obj,
bool obj_is_type) {
bool class_not_seen_yet = true;
PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
for (auto& arg : *overloaded_args) {
- if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
+ if (obj_type == get_type_of_overloaded_arg(arg)) {
// obj is the same type as another parameter we've seen in a prior
// iteration of the loop over parameters so we already have an entry
// with the proper __torch_function__ implementation to call, so skip
@@ -557,9 +548,7 @@ static void append_overloaded_arg(
auto arg_index = overloaded_args->size();
for (const auto j : c10::irange(arg_index)) {
if (PyObject_IsSubclass(
- obj_type,
- (PyObject*)(get_type_of_overloaded_arg(
- (*overloaded_args)[j].ptr())))) {
+ obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
// obj is a subclass of another object we've seen already so its
// __torch_function__ should be called first, therefore we
// insert it into overloaded_args before the superclass
@@ -576,20 +565,20 @@ static void append_overloaded_arg(
}

void append_overloaded_tensor(
- std::vector<py::handle>* overloaded_args,
+ std::vector<PyObject*>* overloaded_args,
PyObject* obj) {
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
}

void append_overloaded_type(
- std::vector<py::handle>* overloaded_args,
+ std::vector<PyObject*>* overloaded_args,
PyObject* obj) {
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
}

bool is_tensor_and_append_overloaded(
PyObject* obj,
- std::vector<py::handle>* overloaded_args) {
+ std::vector<PyObject*>* overloaded_args) {
if (THPVariable_CheckExact(obj)) {
// torch.Tensor instances (not subclasses, except for Parameter)
return true;
@@ -626,7 +615,7 @@ static bool is_scalar_list(PyObject* obj) {

bool is_tensor_list_and_append_overloaded(
PyObject* obj,
- std::vector<py::handle>* overloaded_args,
+ std::vector<PyObject*>* overloaded_args,
int argnum,
bool throw_error) {
auto tuple = six::isTuple(obj);
@@ -790,7 +779,7 @@ static bool is_int_or_symint_list(
// argnum is needed for raising the TypeError, it's used in the error message.
auto FunctionParameter::check(
PyObject* obj,
- std::vector<py::handle>& overloaded_args,
+ std::vector<PyObject*>& overloaded_args,
int argnum,
int64_t* failed_idx) -> bool {
switch (type_) {
@@ -1347,6 +1336,7 @@ bool FunctionSignature::parse(
PyObject* args,
PyObject* kwargs,
PyObject* dst[], // NOLINT
+ std::vector<PyObject*>& overloaded_args,
bool raise_exception) {
Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
@@ -1374,13 +1364,9 @@ bool FunctionSignature::parse(
return false;
}

- if (!overloaded_args.empty()) {
- overloaded_args.clear();
- }

int i = 0;
if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
- append_overloaded_tensor(&this->overloaded_args, self);
+ append_overloaded_tensor(&overloaded_args, self);
}
for (auto& param : params) {
PyObject* obj = nullptr;
@@ -1415,7 +1401,7 @@ bool FunctionSignature::parse(
missing_args(*this, i);
}
return false;
- } else if (param.check(obj, this->overloaded_args, i, &failed_idx)) {
+ } else if (param.check(obj, overloaded_args, i, &failed_idx)) {
dst[i++] = obj;
// XXX: the Variable check is necessary because sizes become tensors when
// tracer is enabled. This behavior easily leads to ambiguities, and we
@@ -1541,15 +1527,20 @@ PythonArgs PythonArgParser::raw_parse(
PyObject* parsed_args[]) { // NOLINT
if (signatures_.size() == 1) {
auto& signature = signatures_[0];
- signature.parse(self, args, kwargs, parsed_args, true);
+ std::vector<PyObject*> overloaded_args;
+ signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
check_deprecated(signature);
- return PythonArgs(traceable, signature, parsed_args);
+ return PythonArgs(
+ traceable, signature, parsed_args, std::move(overloaded_args));
}

for (auto& signature : signatures_) {
- if (signature.parse(self, args, kwargs, parsed_args, false)) {
+ std::vector<PyObject*> overloaded_args;
+ if (signature.parse(
+ self, args, kwargs, parsed_args, overloaded_args, false)) {
check_deprecated(signature);
- return PythonArgs(traceable, signature, parsed_args);
+ return PythonArgs(
+ traceable, signature, parsed_args, std::move(overloaded_args));
}
}

@@ -1575,7 +1566,8 @@ void PythonArgParser::print_error(

if (plausible_idxs.size() == 1) {
auto& signature = signatures_[plausible_idxs[0]];
- signature.parse(self, args, kwargs, parsed_args, true);
+ std::vector<PyObject*> overloaded_args;
+ signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
}

auto options = get_signatures();
