Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix classmethod override argument passing. #47114

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
7 changes: 7 additions & 0 deletions test/test_overrides.py
Expand Up @@ -822,5 +822,12 @@ def test_gradcheck(self):
})


class TestGradNewOnesOverride(TestCase):
    """Regression test for gh-47069: ``Tensor.new_ones`` must route through
    ``__torch_function__`` so that subclass identity is preserved."""

    def test_newones(self):
        base = torch.tensor([1, 2]).as_subclass(SubTensor2)
        result = base.new_ones((1, 2))
        # The subclass override must win: new_ones returns a SubTensor2,
        # not a plain torch.Tensor.
        self.assertEqual(type(result), SubTensor2)

# Allow this test module to be run directly from the command line.
if __name__ == '__main__':
    unittest.main()
36 changes: 18 additions & 18 deletions tools/autograd/templates/python_variable_methods.cpp
Expand Up @@ -50,7 +50,7 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "_is_view");
return handle_torch_function(self, "_is_view", args);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
if (self_.is_view()) {
Expand Down Expand Up @@ -154,7 +154,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self_)) {
return handle_torch_function(self_, "get_device");
return handle_torch_function(self_, "get_device", args, nullptr);
}
auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
return wrap(self.get_device());
Expand All @@ -165,7 +165,7 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self_)) {
return handle_torch_function(self_, "has_names");
return handle_torch_function(self_, "has_names", args);
}
auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
return wrap(self.has_names());
Expand All @@ -177,7 +177,7 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self_)) {
return handle_torch_function(self_, "data_ptr");
return handle_torch_function(self_, "data_ptr", args);
}
auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
return wrap(self.data_ptr());
Expand All @@ -201,7 +201,7 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "dim");
return handle_torch_function(self, "dim", args);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
return THPUtils_packInt64(self_.dim());
Expand All @@ -213,7 +213,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "numel");
return handle_torch_function(self, "numel", args);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
return THPUtils_packInt64(self_.numel());
Expand Down Expand Up @@ -327,7 +327,7 @@ static bool dispatch_to_Bool(const Tensor & self) {
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "__float__");
return handle_torch_function(self, "__float__", args);
}
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand All @@ -338,7 +338,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "__complex__");
return handle_torch_function(self, "__complex__", args);
}
jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand All @@ -349,7 +349,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "__int__");
return handle_torch_function(self, "__int__", args);
}
jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand All @@ -368,7 +368,7 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "__index__");
return handle_torch_function(self, "__index__", args);
}
jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand All @@ -390,7 +390,7 @@ static Tensor dispatch_invert(const Tensor & self) {
static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "__invert__");
return handle_torch_function(self, "__invert__", args);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) {
Expand Down Expand Up @@ -685,7 +685,7 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "element_size");
return handle_torch_function(self, "element_size", args);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
return THPUtils_packInt64(self_.element_size());
Expand Down Expand Up @@ -763,7 +763,7 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "item");
return handle_torch_function(self, "item", args);
}
jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand Down Expand Up @@ -832,7 +832,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "new");
return handle_torch_function(self, "new", args, kwargs);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
OptionalDeviceGuard device_guard(device_of(self_));
Expand All @@ -844,7 +844,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject*
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "new_ones");
return handle_torch_function(self, "new_ones", args, kwargs);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
OptionalDeviceGuard device_guard(device_of(self_));
Expand All @@ -856,7 +856,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "new_tensor");
return handle_torch_function(self, "new_tensor", args, kwargs);
}
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
OptionalDeviceGuard device_guard(device_of(self_));
Expand Down Expand Up @@ -935,7 +935,7 @@ static PyObject * THPVariable_tolist(PyObject* self, PyObject* args)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "tolist");
return handle_torch_function(self, "tolist", args);
}
jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW);
auto self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand Down Expand Up @@ -1004,7 +1004,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
if (check_has_torch_function(self)) {
HANDLE_TH_ERRORS
return handle_torch_function(self, "__bool__");
return handle_torch_function(self, "__bool__", args);
END_HANDLE_TH_ERRORS
}
jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW);
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/utils/python_arg_parser.cpp
Expand Up @@ -139,7 +139,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
// Dispatch a Tensor property read (e.g. `t.grad`) to the subclass's
// __torch_function__ override.
//
// `self` is the tensor whose property is being accessed; `property_name` is
// the attribute name on torch.Tensor. The property descriptor is looked up on
// THPVariableClass and passed through as the "torch API function"; its
// `__get__` slot is then invoked via handle_torch_function with no positional
// arguments and no keyword arguments (both nullptr).
//
// Returns a new reference to the override's result, or nullptr with a Python
// error set.
auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* {
  // Resolve the property object torch.Tensor.<property_name>.
  py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str());
  std::string module_name = "torch.Tensor." + property_name;
  // args == nullptr, kwargs == nullptr: a getter takes no arguments beyond self.
  return handle_torch_function((PyObject *)self, "__get__", nullptr, nullptr, torch_api.ptr(), module_name);
}

auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int {
Expand All @@ -148,10 +148,10 @@ auto handle_torch_function_setter(THPVariable* self, const std::string& property
if (value != nullptr)
{
py::tuple args_ = py::make_tuple(py::handle(value));
handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name);
handle_torch_function((PyObject *)self, "__set__", args_.ptr(), nullptr, torch_api.ptr(), module_name);
}
else {
handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name);
handle_torch_function((PyObject *)self, "__delete__", nullptr, nullptr, torch_api.ptr(), module_name);
}
return 0;
}
Expand All @@ -175,13 +175,13 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple {
return args_;
}

auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* {
auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* {
py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist");
py::tuple args_ = combine_self_args(self, args);
py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self)));
py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__");
py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL));
py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs));
if (ret.ptr() == nullptr) {
// if an exception occurred in a user's implementation of
// __torch_function__, throw it
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/utils/python_arg_parser.h
Expand Up @@ -820,8 +820,8 @@ auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObje
// Used for functions which needs to parse python args.
auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*;

// Used for functions that accept no keyword arguments and have no argument parsing
auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*;
// Used for functions that have no argument parsing.
auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*;

// Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args.
auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*;
Expand Down