feature: adding the ability to restore shapes after loading a traced model #90744
Conversation
Left some comments.
What's the use case for this change? Shapes on traced graphs are, afaik, just an artifact of the way the graph is produced that's useful for debugging; they don't provide any guards or guarantees. Carrying around copies of the input data seems like a pretty large price to be able to have this information.
Also, I wonder if we can just serialize the shapes (or Type pointers) directly instead of serializing IValues? This would allow us to (a) avoid modifying the jit::Module interface to have to store values, and (b) store only the metadata we need, instead of all the data that is associated with the inputs
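For illustration, a minimal Python sketch of the "metadata only" idea described above: record just the shape, dtype, and device of each example tensor rather than keeping the tensors themselves. summarize_example_inputs is a hypothetical helper for this sketch, not code from this PR.

import torch

def summarize_example_inputs(example_inputs):
    # Hypothetical helper: keep only the metadata needed to re-annotate
    # graph inputs, instead of serializing the tensor data itself.
    summaries = []
    for value in example_inputs:
        if isinstance(value, torch.Tensor):
            summaries.append({
                "shape": tuple(value.shape),
                "dtype": str(value.dtype),
                "device": str(value.device),
            })
        else:
            summaries.append(None)  # non-tensor inputs carry no shape info
    return summaries

# The metadata footprint is tiny compared to the tensor data itself.
print(summarize_example_inputs((torch.randn(4, 3, 224, 224), 7)))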
rewriteQuantizedConvForBC(m);
// Checking for and loading saved traced inputs
if (reader_->hasRecord("traced_inputs.pkl")) {
nit: imo we shouldn't attempt this if restore_shapes=False; if there are errors in this step, there should be a way to skip it.
Makes sense, I'll fix that in the next commit.
@@ -908,6 +911,7 @@ def trace_module(
    _module_class=None,
    _compilation_unit=_python_cu,
    example_inputs_is_kwarg=False,
    _store_inputs=True,
I think it's better to default to False here:
- I would guess that most users don't need this behavior
- The overhead for saving these could be pretty large if input sizes are large.
I'd like to default to on if possible, so that people who want to use a tool that accesses this information wouldn't need to retrace or change their current experience.
I've made changes so that the overhead should be only ~400 bytes per input tensor after serializing the model.
sorry, but I still prefer to default to false here. is there a reason it needs to be true?
nvm see your description below. let me ask around on this
hmmmm I see the loading of it defaults to false so it's fine... the raw file size on disk is less important than the memory consumption of the loaded model.
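For the ~400-bytes-per-tensor estimate above, a rough way to check the serialized overhead is to save the same traced module with and without input storage and compare sizes. A hedged sketch, assuming _store_inputs is exposed on torch.jit.trace_module as in the diff above:

import io
import torch

class Small(torch.nn.Module):
    def forward(self, x):
        return x * 2

example = torch.randn(8, 16)

# Trace twice: once storing the example inputs, once without.
with_inputs = torch.jit.trace_module(Small(), {"forward": example}, _store_inputs=True)
without_inputs = torch.jit.trace_module(Small(), {"forward": example}, _store_inputs=False)

def saved_size(module):
    # Serialize to an in-memory buffer and report the byte count.
    buf = io.BytesIO()
    torch.jit.save(module, buf)
    return len(buf.getvalue())

# The difference approximates the per-model overhead of storing traced inputs.
print(saved_size(with_inputs) - saved_size(without_inputs))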
@@ -550,6 +550,50 @@ def forward(self, x):
    self.assertTrue(m_buffers["buffer"].is_meta)
    self.assertTrue(m_loaded_buffers["buffer"].is_meta)

def test_save_load_with_saved_traced_inputs(self):
can you also test some forward functions that take container types (tuples, lists) as inputs? I'm a bit worried about how setInputTensorTypes handles container types. Please add variable input sizes, etc.
Also, what is the expectation for how container types behave here? e.g. should a list of tensors be annotated as a List[Tensor{dtype, size info}] or as a List[generic-Tensor]? What if the contained tensor metadata don't match each other?
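For reference, a sketch of the kind of container-type check being asked for here: trace a forward that takes a tuple of differently-sized tensors and inspect how the graph annotates that input. Illustrative only; not the test that was added to the PR.

import torch

class TakesPair(torch.nn.Module):
    def forward(self, pair):
        a, b = pair
        return torch.cat([a, b], dim=0)

# Two tensors with different first dimensions inside one tuple argument.
example_pair = (torch.randn(2, 3), torch.randn(4, 3))
traced = torch.jit.trace(TakesPair(), (example_pair,))

# Print how each graph input is typed; the first graph input is `self`.
for inp in traced.graph.inputs():
    print(inp.debugName(), ":", inp.type())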
Will do
did you add these tests? I don't see them
I added them to the same test; I've gone ahead and added some comments to make that clearer.
Alternatively, you are also free to implement this on top of pytorch for your own use cases; you can create wrappers for import / export that dump input types (during export) and load & apply the shapes (during import).
Hi David, thanks for the comments. I'll start working on most of the changes you suggested. The use case here is to allow a tool I work with to read the shape information from saved traced models, to simplify the customer experience. I'm also looking to not require people to know that they are using this tool before tracing, if possible, which is why I want saving the information to default to on. I'll see what I can do about only saving the shapes to make this a less expensive operation.
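To make that workflow concrete, a hedged end-to-end sketch: trace and save a model, then have the tool load it with shape restoration turned on and read the shapes off the graph inputs. The load-side keyword spelling is an assumption here (the review refers to a restore_shapes flag that defaults to off).

import io
import torch

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

traced = torch.jit.trace(Net(), torch.randn(1, 3, 224, 224))

buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)

# Ask for the traced input shapes to be re-applied to the loaded graph.
# The keyword name "_restore_shapes" is assumed, not confirmed above.
loaded = torch.jit.load(buf, _restore_shapes=True)

# A downstream tool can now read shape information straight off the graph.
for inp in list(loaded.graph.inputs())[1:]:  # skip the `self` input
    print(inp.debugName(), inp.type())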
left comments on how to fix the bazel build
torch/csrc/jit/ir/graph_utils.h (outdated)
namespace torch {
namespace jit {

TypePtr getTensorType(const at::Tensor& t, bool complete);
I think you need to prefix these with TORCH_API (i.e. prefix all the function definitions with TORCH_API).
Updated in the latest version
build_variables.bzl (outdated)
@@ -822,6 +823,7 @@ libtorch_python_core_sources = [
    "torch/csrc/dynamo/init.cpp",
    "torch/csrc/functorch/init.cpp",
    "torch/csrc/jit/backends/backend_init.cpp",
    "torch/csrc/jit/ir/graph_utils.cpp",
I think you should remove this to fix the bazel build issue
updated in the latest version
LGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Revert "feature: adding the ability to restore shapes after loading a traced model (pytorch#90744)". This reverts commit 0d0ebcd.
@huydhn do you know why this was reverted?
oh I see, thanks :) I wasn't able to find the internal diff but now I've found it.
Adds the ability to store the inputs used when tracing a model via torch.jit.save, and to restore the input shapes via torch.jit.load when the appropriate flags are set.
Fixes #89185
cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen