Skip to content

Conversation

anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Sep 25, 2023

Stack from ghstack (oldest at bottom):

This is the output for nn module guards

[DEBUG] GUARDS:
[DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # _dynamo/variables/builder.py:1356 in wrap_fx_proxy_cls
[DEBUG] ___check_obj_id(L['self'], 139820807110912)                   # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_0(L['self']) # versions(mod=9998, _parameters=1194395, _buffers=1194397, _modules=1194423, _forward_hooks=1194405, _forward_pre_hooks=1194411, _backward_hooks=1194402, _backward_pre_hooks=1194400)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0], 139817945727568)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_1(L['self'].mods[0]) # versions(mod=10001, _parameters=1194428, _buffers=1194430, _modules=1194522, _forward_hooks=1194438, _forward_pre_hooks=1194444, _backward_hooks=1194435, _backward_pre_hooks=1194433)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[1], 139817945560640)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_2(L['self'].mods[1]) # versions(mod=10001, _parameters=1194660, _buffers=1194662, _modules=1194753, _forward_hooks=1194670, _forward_pre_hooks=1194676, _backward_hooks=1194667, _backward_pre_hooks=1194665)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0].linear, 139817945727856)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_3(L['self'].mods[0].linear) # versions(mod=10004, _parameters=1470004, _buffers=1194467, _modules=1194493, _forward_hooks=1194475, _forward_pre_hooks=1194481, _backward_hooks=1194472, _backward_pre_hooks=1194470)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] ___check_obj_id(L['self'].mods[1].linear, 139817945561120)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_4(L['self'].mods[1].linear) # versions(mod=10004, _parameters=1470008, _buffers=1194699, _modules=1194725, _forward_hooks=1194707, _forward_pre_hooks=1194713, _backward_hooks=1194704, _backward_pre_hooks=1194702)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:373 in init_ambient_guards

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110028

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7c2e814 with merge base a902150 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ions for debugging"


This is the output for nn module guards

~~~
[DEBUG] GUARDS:
[DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # _dynamo/variables/builder.py:1356 in wrap_fx_proxy_cls
[DEBUG] ___check_obj_id(L['self'], 139820807110912)                   # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_0(L['self']) # versions(mod=9998, _parameters=1194395, _buffers=1194397, _modules=1194423, _forward_hooks=1194405, _forward_pre_hooks=1194411, _backward_hooks=1194402, _backward_pre_hooks=1194400)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0], 139817945727568)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_1(L['self'].mods[0]) # versions(mod=10001, _parameters=1194428, _buffers=1194430, _modules=1194522, _forward_hooks=1194438, _forward_pre_hooks=1194444, _backward_hooks=1194435, _backward_pre_hooks=1194433)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[1], 139817945560640)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_2(L['self'].mods[1]) # versions(mod=10001, _parameters=1194660, _buffers=1194662, _modules=1194753, _forward_hooks=1194670, _forward_pre_hooks=1194676, _backward_hooks=1194667, _backward_pre_hooks=1194665)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0].linear, 139817945727856)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_3(L['self'].mods[0].linear) # versions(mod=10004, _parameters=1470004, _buffers=1194467, _modules=1194493, _forward_hooks=1194475, _forward_pre_hooks=1194481, _backward_hooks=1194472, _backward_pre_hooks=1194470)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] ___check_obj_id(L['self'].mods[1].linear, 139817945561120)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_4(L['self'].mods[1].linear) # versions(mod=10004, _parameters=1470008, _buffers=1194699, _modules=1194725, _forward_hooks=1194707, _forward_pre_hooks=1194713, _backward_hooks=1194704, _backward_pre_hooks=1194702)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:373 in init_ambient_guards
~~~

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Sep 25, 2023
code_parts.append(code)

# Log the verbose part
nn_module_guard = self._extra_closure_vars[code.split("(")[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand what this does

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would be more idiomatic if the "comment" is actually just arguments to the "function"

…ions for debugging"


This is the output for nn module guards

~~~
[DEBUG] GUARDS:
[DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # _dynamo/variables/builder.py:1356 in wrap_fx_proxy_cls
[DEBUG] ___check_obj_id(L['self'], 139820807110912)                   # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_0(L['self']) # versions(mod=9998, _parameters=1194395, _buffers=1194397, _modules=1194423, _forward_hooks=1194405, _forward_pre_hooks=1194411, _backward_hooks=1194402, _backward_pre_hooks=1194400)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0], 139817945727568)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_1(L['self'].mods[0]) # versions(mod=10001, _parameters=1194428, _buffers=1194430, _modules=1194522, _forward_hooks=1194438, _forward_pre_hooks=1194444, _backward_hooks=1194435, _backward_pre_hooks=1194433)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[1], 139817945560640)           # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] __nn_module_guard_2(L['self'].mods[1]) # versions(mod=10001, _parameters=1194660, _buffers=1194662, _modules=1194753, _forward_hooks=1194670, _forward_pre_hooks=1194676, _backward_hooks=1194667, _backward_pre_hooks=1194665)  # for mod in self.mods:  # examples/graph_break.py:35 in forward
[DEBUG] ___check_obj_id(L['self'].mods[0].linear, 139817945727856)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_3(L['self'].mods[0].linear) # versions(mod=10004, _parameters=1470004, _buffers=1194467, _modules=1194493, _forward_hooks=1194475, _forward_pre_hooks=1194481, _backward_hooks=1194472, _backward_pre_hooks=1194470)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] ___check_obj_id(L['self'].mods[1].linear, 139817945561120)    # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] __nn_module_guard_4(L['self'].mods[1].linear) # versions(mod=10004, _parameters=1470008, _buffers=1194699, _modules=1194725, _forward_hooks=1194707, _forward_pre_hooks=1194713, _backward_hooks=1194704, _backward_pre_hooks=1194702)  # return self.linear(a)  # examples/graph_break.py:24 in helper
[DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:373 in init_ambient_guards
~~~

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Sep 26, 2023
@anijain2305
Copy link
Contributor Author

nit: it would be more idiomatic if the "comment" is actually just arguments to the "function"

Thanks that was cleaner. kwargs are ignored in the C call, just added a debug_msg kwargs.

@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 26, 2023
@anijain2305 anijain2305 added the topic: not user facing topic category label Sep 26, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants