Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ def NN_MODULE(self, guard: Guard):
)
return
name = self.check_fn_manager.add_extra_closure_var("__nn_module_guard", g)
self._produce_guard_code(guard, [f"{name}({ref})"])
# debug_msg is only for debugging help and goes to kwargs of guard call,
# which is ignored.
self._produce_guard_code(guard, [f'{name}({ref}, debug_msg="{g}")'])

def FUNCTION_MATCH(self, guard: Guard):
"""things like torch.add and user defined functions"""
Expand Down
18 changes: 18 additions & 0 deletions torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,23 @@ static PyObject* NNModuleGuard_call(
Py_RETURN_TRUE;
}

static PyObject* NNModuleGuard_repr(PyObject* self) {
// Prints versions of the module and the attributes.
NNModuleGuard* guard = (NNModuleGuard*)self;
std::ostringstream oss;
oss << "versions(mod=" << guard->dict_version_tag;

for (size_t index = 0;
index < sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0]);
index++) {
oss << ", " << module_guard_attrs[index] << "="
<< guard->attr_tags[index].dict_version_tag;
}

oss << ")";
return Py_BuildValue("s", oss.str().c_str());
}

static PyObject* nn_module_guard(PyObject* dummy, PyObject* obj) {
// Uses a private tags introduced in PEP 509 - ma_version_tag to check if
// there are any changes in the dict.
Expand Down Expand Up @@ -763,6 +780,7 @@ PyObject* torch_c_dynamo_guards_init() {
NNModuleGuardType.tp_call = NNModuleGuard_call;
NNModuleGuardType.tp_dealloc = (destructor)NNModuleGuard_dealloc;
NNModuleGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
NNModuleGuardType.tp_repr = NNModuleGuard_repr;

GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
Expand Down