
feature: adding the ability to restore shapes after loading a traced model #90744

Closed
wants to merge 1 commit

Conversation

Contributor

@mnuyens mnuyens commented Dec 13, 2022

Adds the ability to store the inputs used when tracing a model so they are serialized by torch.jit.save, and to restore the input shapes via torch.jit.load when the appropriate flags are set.

Fixes #89185

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen
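For context, a minimal sketch of the intended usage, assuming the flag names from this PR: _store_inputs on the tracing side (as in the diff below) and _restore_shapes on the loading side (as referenced later in the review).

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

example = torch.randn(2, 3)

# Store the tracing inputs alongside the traced module
# (_store_inputs is the flag added in the diff below).
traced = torch.jit.trace(M(), example, _store_inputs=True)
torch.jit.save(traced, "m.pt")

# Opt in to restoring the traced input shapes at load time
# (_restore_shapes is the load-side flag discussed in review; it defaults to False).
loaded = torch.jit.load("m.pt", _restore_shapes=True)
print(loaded.graph)  # graph inputs should carry concrete shape info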


pytorch-bot bot commented Dec 13, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90744

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 60263cc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Dec 13, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: mnuyens / name: Maxwell Nuyens (45344db)

@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Dec 13, 2022
@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 15, 2022
Contributor

@davidberard98 davidberard98 left a comment


Left some comments.

What's the use case for this change? Shapes on traced graphs are, afaik, just an artifact of the way the graph is produced that's useful for debugging; they don't provide any guards or guarantees. Carrying around copies of the input data seems like a pretty large price to be able to have this information.

Also, I wonder if we can just serialize the shapes (or Type pointers) directly instead of serializing IValues? This would allow us to (a) avoid modifying the jit::Module interface to have to store values, and (b) store only the metadata we need, instead of all the data that is associated with the inputs
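As a rough illustration of that alternative (a sketch only, with a hypothetical helper name, not the PR's implementation), the serialized payload could be reduced to type-level metadata:

import torch

def extract_input_metadata(example_inputs):
    # Hypothetical helper: record only what a TensorType needs
    # (sizes/dtype/device), not the tensor data itself.
    meta = []
    for inp in example_inputs:
        if isinstance(inp, torch.Tensor):
            meta.append({
                "kind": "Tensor",
                "sizes": tuple(inp.shape),
                "dtype": str(inp.dtype),
                "device": str(inp.device),
            })
        else:
            meta.append({"kind": type(inp).__name__})
    return meta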

rewriteQuantizedConvForBC(m);
// Checking for and loading saved traced inputs
if (reader_->hasRecord("traced_inputs.pkl")) {
Contributor


nit: imo we shouldn't attempt this if restore_shapes=False; if there are errors in this step, there should be a way to skip it.

Contributor Author


Makes sense, I'll fix that in the next commit.

@@ -908,6 +911,7 @@ def trace_module(
_module_class=None,
_compilation_unit=_python_cu,
example_inputs_is_kwarg=False,
_store_inputs=True,
Contributor


I think it's better to default to False here:

  • I would guess that most users don't need this behavior
  • The overhead for saving these could be pretty large if input sizes are large.

Contributor Author


I'd like to default this to on, if possible, so that people who want to use a tool that accesses this information don't need to retrace or change their current workflow.

Contributor Author


I've made changes so that the overhead should be only ~400 bytes per input tensor after serializing the model.
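A back-of-envelope check of that figure (illustrative only; the PR's actual serialization format may differ): pickling just the type-level metadata of a tensor costs a few hundred bytes regardless of how large the tensor is.

import pickle
import torch

t = torch.randn(8, 3, 224, 224)  # ~4.8 MB of tensor data
meta = {
    "sizes": tuple(t.shape),
    "strides": t.stride(),
    "dtype": str(t.dtype),
    "device": str(t.device),
    "requires_grad": t.requires_grad,
}
print(len(pickle.dumps(meta)))  # a few hundred bytes, independent of numel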

Contributor


sorry, but I still prefer to default to false here. is there a reason it needs to be true?

Contributor


nvm see your description below. let me ask around on this

Contributor


hmmmm I see the loading of it defaults to false so it's fine... the raw file size on disk is less important than the memory consumption of the loaded model.

@@ -550,6 +550,50 @@ def forward(self, x):
self.assertTrue(m_buffers["buffer"].is_meta)
self.assertTrue(m_loaded_buffers["buffer"].is_meta)

def test_save_load_with_saved_traced_inputs(self):
Contributor


can you also test some forward functions that take container types (tuples, lists) as inputs? I'm a bit worried about how setInputTensorTypes handles container types. Please add variable input sizes, etc.

Also, what is the expectation for how container types behave here? e.g. should a list of tensors be annotated as a List[Tensor{dtype, size info}] or as a List[generic-Tensor]? What if the contained tensor metadata don't match each other?
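A sketch of the kind of container-type case being asked about (the module and input sizes here are illustrative, not the test that was eventually added):

import torch

class TupleInput(torch.nn.Module):
    def forward(self, pair):
        a, b = pair  # Tuple[Tensor, Tensor] with differing shapes
        return a.sum() + b.sum()

# Variable input sizes inside a container:
traced = torch.jit.trace(TupleInput(), ((torch.randn(2, 3), torch.randn(5)),))
# The open question above: after save/load with shape restoration, is the input
# annotated as Tuple[Float(2, 3), Float(5)] or as Tuple[Tensor, Tensor]?
print(traced.graph)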

Contributor Author


Will do

Contributor


did you add these tests? I don't see them

Contributor Author


I added them to the same test; I've gone ahead and added some comments to make that clearer.

@davidberard98
Contributor

Alternatively, you are free to implement this on top of pytorch for your own use cases; you can create wrappers for export / import that dump input types (during export) and load & apply the shapes (during import).
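Such a wrapper might look like the following (a minimal sketch; the function names and sidecar-file scheme are made up for illustration, and actually applying the shapes to the loaded graph would need internal APIs not shown here):

import json
import torch

def save_with_shapes(traced_module, path, example_inputs):
    # Save the model plus a sidecar file recording the input types.
    torch.jit.save(traced_module, path)
    shapes = [{"sizes": list(t.shape), "dtype": str(t.dtype)}
              for t in example_inputs]
    with open(path + ".shapes.json", "w") as f:
        json.dump(shapes, f)

def load_with_shapes(path):
    # Load the model together with the recorded input shapes.
    module = torch.jit.load(path)
    with open(path + ".shapes.json") as f:
        shapes = json.load(f)
    return module, shapes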

Contributor Author

mnuyens commented Dec 20, 2022

Hi David,

Thanks for the comments, I'll start working on most of the changes you suggested.

The use case here is to let a tool I work with read the shape information from saved traced models, to simplify the customer experience. I'm also trying not to require people to know they'll be using this tool before tracing, which is why I want the saving side to default to on. I'll see what I can do about saving only the shapes to make this a less expensive operation.

@github-actions github-actions bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Feb 8, 2023
@mnuyens mnuyens force-pushed the master branch 2 times, most recently from c84abc3 to 62e6ab1 Compare February 9, 2023 19:28
Contributor

@davidberard98 davidberard98 left a comment


left comments on how to fix the bazel build

namespace torch {
namespace jit {

TypePtr getTensorType(const at::Tensor& t, bool complete);
Contributor


I think you need to prefix these with TORCH_API (i.e. prefix all the function declarations with TORCH_API)

Contributor Author


Updated in the latest version

@@ -822,6 +823,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/ir/graph_utils.cpp",
Contributor


I think you should remove this to fix the bazel build issue

Contributor Author


updated in the latest version

Contributor

@davidberard98 davidberard98 left a comment


LGTM!

@davidberard98
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 9, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

Contributor Author

mnuyens commented Feb 10, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

huydhn added a commit to huydhn/pytorch that referenced this pull request Feb 11, 2023
@davidberard98
Contributor

@huydhn do you know why this was reverted?

Contributor

huydhn commented Feb 25, 2023

@huydhn do you know why this was reverted?

Oh, this is not reverted in trunk. I was trying to debug issue #89395 and reverted the commit on my test PR; that revert's commit message shows up here misleadingly.

@davidberard98
Contributor

oh I see, thanks :) I wasn't able to find the internal diff but now I've found it.

Labels
  • ciflow/trunk Trigger trunk jobs on your pull request
  • Merged
  • module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration
  • open source
  • release notes: jit release notes category
  • triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Successfully merging this pull request may close these issues.

[feature request] Add ability to preserve traced shape during torch.jit.save and torch.jit.load