From 912d8ae06913006274dc9eadf439348370686244 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi Date: Fri, 30 Oct 2020 11:27:44 +0000 Subject: [PATCH 1/2] Fix classmethod override argument passing. --- test/test_overrides.py | 7 ++++ .../templates/python_variable_methods.cpp | 36 +++++++++---------- torch/csrc/utils/python_arg_parser.cpp | 10 +++--- torch/csrc/utils/python_arg_parser.h | 4 +-- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/test/test_overrides.py b/test/test_overrides.py index 4734b3bc7c91..8ec1dabc0d9d 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -822,5 +822,12 @@ def test_gradcheck(self): }) +class TestGradNewOnesOverride(TestCase): + """ Regression test for gh-47069 """ + def test_newones(self): + t = torch.tensor([1, 2]).as_subclass(SubTensor2) + n = t.new_ones((1, 2)) + self.assertEqual(type(n), SubTensor2) + if __name__ == '__main__': unittest.main() diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 3fd833f3af9f..9a51fa9fafad 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -50,7 +50,7 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "_is_view"); + return handle_torch_function(self, "_is_view", args); } auto& self_ = reinterpret_cast(self)->cdata; if (self_.is_view()) { @@ -154,7 +154,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "get_device"); + return handle_torch_function(self_, "get_device", args, nullptr); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.get_device()); @@ -165,7 +165,7 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if 
(check_has_torch_function(self_)) { - return handle_torch_function(self_, "has_names"); + return handle_torch_function(self_, "has_names", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.has_names()); @@ -177,7 +177,7 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "data_ptr"); + return handle_torch_function(self_, "data_ptr", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.data_ptr()); @@ -201,7 +201,7 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "dim"); + return handle_torch_function(self, "dim", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.dim()); @@ -213,7 +213,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "numel"); + return handle_torch_function(self, "numel", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.numel()); @@ -327,7 +327,7 @@ static bool dispatch_to_Bool(const Tensor & self) { static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__float__"); + return handle_torch_function(self, "__float__", args); } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -338,7 +338,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__complex__"); + return handle_torch_function(self, 
"__complex__", args); } jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -349,7 +349,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__int__"); + return handle_torch_function(self, "__int__", args); } jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -368,7 +368,7 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__index__"); + return handle_torch_function(self, "__index__", args); } jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -390,7 +390,7 @@ static Tensor dispatch_invert(const Tensor & self) { static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__invert__"); + return handle_torch_function(self, "__invert__", args); } auto& self_ = reinterpret_cast(self)->cdata; if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { @@ -685,7 +685,7 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "element_size"); + return handle_torch_function(self, "element_size", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.element_size()); @@ -763,7 +763,7 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* 
args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "item"); + return handle_torch_function(self, "item", args); } jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -832,7 +832,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new"); + return handle_torch_function(self, "new", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -844,7 +844,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_ones"); + return handle_torch_function(self, "new_ones", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -856,7 +856,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_tensor"); + return handle_torch_function(self, "new_ones", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -935,7 +935,7 @@ static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "tolist"); + return handle_torch_function(self, "tolist", args); } jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); auto self_ = reinterpret_cast(self)->cdata; @@ -1004,7 +1004,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { if 
(check_has_torch_function(self)) { HANDLE_TH_ERRORS - return handle_torch_function(self, "__bool__"); + return handle_torch_function(self, "__bool__", args); END_HANDLE_TH_ERRORS } jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index ff94b1f5ceca..950e7d9fb82d 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -139,7 +139,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); std::string module_name = "torch.Tensor." + property_name; - return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name); + return handle_torch_function((PyObject *)self, "__get__", nullptr, nullptr, torch_api.ptr(), module_name); } auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { @@ -148,10 +148,10 @@ auto handle_torch_function_setter(THPVariable* self, const std::string& property if (value != nullptr) { py::tuple args_ = py::make_tuple(py::handle(value)); - handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__set__", args_.ptr(), nullptr, torch_api.ptr(), module_name); } else { - handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__delete__", nullptr, nullptr, torch_api.ptr(), module_name); } return 0; } @@ -175,13 +175,13 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { return args_; } -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* 
torch_api, const std::string& module_name) -> PyObject* { +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); py::tuple args_ = combine_self_args(self, args); py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self))); py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__"); - py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL)); + py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs)); if (ret.ptr() == nullptr) { // if an exception occurred in a user's implementation of // __torch_function__, throw it diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 1bd94592d59f..e50792c47bec 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -820,8 +820,8 @@ auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObje // Used for functions which needs to parse python args. auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*; -// Used for functions that accept no keyword arguments and have no argument parsing -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; +// Used for functions that have no argument parsing. 
+auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; // Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args. auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*; From e0ffce19606f31367228db4953b2a74d4332f2e1 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi Date: Mon, 2 Nov 2020 14:57:10 +0000 Subject: [PATCH 2/2] Fix copypasta typo. --- tools/autograd/templates/python_variable_methods.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 9a51fa9fafad..4f11e44a646f 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -856,7 +856,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_ones", args, kwargs); + return handle_torch_function(self, "new_tensor", args, kwargs); } auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_));