[wip] Loading nn.Module from checkpoint tutorial #2519
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2519
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit 1e0c60e, the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks, @mikaylagawarecki! Can you please resubmit as a regular PR? We don't support ghstack in this repo.
@svekars I will open a proper non-ghstack PR for landing; just using this one as scratch space for review for now, if that's alright!
# 2. The user does not want to wait for the entire checkpoint to be loaded
# into RAM before doing, for example, some per-tensor processing
#
# The `mmap` keyword argument to `torch.load` attempts to solve the above two
How to better explain this:
- should we go into the internals of how the zip file is structured?
- should we mention that users should not expect `mmap` behavior on CUDA?
###############################################################################
# The [`torch.device()`](https://pytorch.org/docs/main/tensor_attributes.html#torch-device)
# context manager makes sure that factory calls will be performed as if they
# were passed device as an argument. However, it does not affect factory
There is a new PR (not landed yet) that lets this behavior (factory calls with an explicit device argument are not affected by the context manager) be overridden with TorchFunctionMode. Should that be mentioned?
I think this tutorial should assume the user is using the latest version of main (so, after the PR has landed).
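A small sketch of the context-manager behavior being discussed: inside the block, factory calls behave as if `device` were passed explicitly, so the module's parameters below are created on the meta device and no real storage is allocated.

```python
import torch
import torch.nn as nn

# Factory calls inside this block default to the meta device.
with torch.device("meta"):
    layer = nn.Linear(8, 8)

print(layer.weight.device)   # meta
print(layer.weight.is_meta)  # True
```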
2. The `torch.device()` context manager
3. The `assign` keyword argument on `nn.Module.load_state_dict()`

The following snippet of code illustrates the use of the above three utilities.
I like the idea of having a "tl;dr" at the top of the tutorial; we can make it more explicit: tl;dr: if you're loading a checkpoint and want to reduce compute and memory as much as possible, do the following:
#
# The `mmap` keyword argument to `torch.load` attempts to solve the above two
# problems by using an [`mmap` call](https://man7.org/linux/man-pages/man2/mmap.2.html)
# on the checkpoint, so that tensor storages are memory-mapped and when they are
nit: I think this is a bit too much "in the middle": if you assume the user doesn't know `mmap`, add one sentence explaining what it does (map the file on disk into virtual memory and let the OS handle loading/unloading into physical memory automatically). If you assume they already know, simplify this paragraph.
# on the checkpoint, so that tensor storages are memory-mapped and the OS
# manages when they are fetched from disk into memory.
#
# Next, we consider the creation of the module.
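If we do go with the one-sentence explanation, a short illustration could back it up. Nothing torch-specific here, just a sketch of what an `mmap(2)` call does using Python's `mmap` module (the filename is made up):

```python
import mmap

# Write one page of data to a file on disk.
with open("data.bin", "wb") as f:
    f.write(b"\x07" * 4096)

# Map the file into virtual memory. Mapping does not read the file; the OS
# faults pages in from disk only when the bytes are actually touched.
with open("data.bin", "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
        first = mapped[0]  # this access is what pages the byte in

print(first)  # 7
```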
Have titles? Not sure how this renders
# other metadata a tensor carries such as `.size()` and `.stride()`,
# `.requires_grad` etc.
#
# Next, we consider the loading of the state dictionary.
paragraph
Stack from ghstack (oldest at bottom):