treat Parameter the same way as Tensor #48963

Closed. Wants to merge 15 commits (showing changes from 12 commits).
8 changes: 8 additions & 0 deletions torch/csrc/autograd/init.cpp
@@ -39,6 +39,14 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
auto _C_m = py::handle(torch_C_module).cast<py::module>();
auto m = _C_m.def_submodule("_autograd", "autograd bindings");

auto parameter_module = THPObjectPtr(PyImport_ImportModule("torch.nn.parameter"));
Collaborator:

Doesn't really matter much, but out of curiosity, is this a py::object?

Author:

It looks like there's an analog from TH. (Based on some grepping, it seems to just be handling the incref and decref.) I just cargo-culted the other ones.

if (!parameter_module)
return nullptr;

// NOTE: "leaks" ParameterClass
ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter");
if (!ParameterClass)
return nullptr;

py::enum_<ProfilerState>(m, "ProfilerState")
.value("Disabled", ProfilerState::Disabled)
5 changes: 5 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
@@ -44,6 +44,11 @@ namespace py = pybind11;

PyObject *THPVariableClass = nullptr;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject *ParameterClass = nullptr;

// clang-tidy gets confused by static const
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static const char* VOLATILE_WARNING =
"volatile was removed and now has no effect. Use "
"`with torch.no_grad():` instead.";
7 changes: 6 additions & 1 deletion torch/csrc/autograd/python_variable.h
@@ -19,12 +19,17 @@ struct THPVariable {
};

THP_API PyObject *THPVariableClass;
THP_API PyObject *ParameterClass;

bool THPVariable_initModule(PyObject *module);
THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);

static inline bool THPVariable_CheckExact(PyObject *obj) {
-  return Py_TYPE(obj) == (PyTypeObject*)THPVariableClass;
+  auto obj_py_type = Py_TYPE(obj);
+  return (
+      obj_py_type == (PyTypeObject*)THPVariableClass ||
+      obj_py_type == (PyTypeObject*)ParameterClass
+  );
}

inline bool THPVariable_Check(PyObject *obj)