Add torch.serialization.skip_data context manager #134504
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134504
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0cc1afb with merge base 5a0e7a4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The semantic is:

(1) If real tensors are passed, storage bytes will not be written, but zipfile metadata plus sufficient space for the storages will be reserved in the checkpoint:

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
torch.save(sd, 'foo.pt', metadata_only=True)
print(torch.load('foo.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]])), ('bias', tensor([0., 0., 0., 0., 0.]))])
```

(2) If a FakeTensor is passed, space will be reserved in the checkpoint for the associated storage:

```python
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()

torch.save(sd, 'bla.pt', metadata_only=True)
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
```diff
 # pointing to the same data all have the same dtype. If storage is
 # not allocated, don't perform this check
-if storage.data_ptr() != 0:
+if str(storage.device) != "meta" and storage.data_ptr() != 0:
```
Using `str(storage.device) != "meta"` to avoid the deprecation message mentioned above when accessing FakeTensor `data_ptr`.
...are there other cases where the device is not meta but the storage `data_ptr` is 0?
xla IIRC
OK, I think this fix is good then, to avoid accessing `data_ptr` specifically for FakeTensor storage.
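A minimal sketch of the guard pattern under discussion, using only standard `torch` API (the helper name `dtype_check_applies` is made up for illustration): because `and` short-circuits, `data_ptr()` is never called on meta-device storages, and unallocated storages are skipped as well.

```python
import torch

def dtype_check_applies(storage: torch.UntypedStorage) -> bool:
    # Skip the dtype-consistency check for meta-device storages (no real allocation,
    # and FakeTensor storages warn on data_ptr() access) and for unallocated storages.
    return str(storage.device) != "meta" and storage.data_ptr() != 0

meta_storage = torch.empty(4, device="meta").untyped_storage()
cpu_storage = torch.empty(4).untyped_storage()

print(dtype_check_applies(meta_storage))  # False: short-circuits before data_ptr()
print(dtype_check_applies(cpu_storage))   # True: real, allocated storage
```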
The semantic is:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()

with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
```python
)
state = torch._utils._get_obj_state(self)
if type(self) is Tensor and not state:
    # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
```
I think this is sub-optimal in the long term. We should have a clearer contract on how we go from a FakeTensor to a real Tensor with no data.
Dropping every single field from it might not be acceptable for everyone, and it might be interesting if there were a method to extract such a Tensor from it easily.
This sounds OK as a temporary solution here, I think.
I agree, technically what we want is to drop every attribute that is in the dict of a "normal" FakeTensor, but keep anything else, is that right?
But agreed that for now it seems tricky to handle this, because I'm not sure that what is in a "normal" FakeTensor's dict is stable.
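A rough sketch of what the `not state` fast path keys on, assuming `torch._utils._get_obj_state` behaves as used in the diff above (the attribute name `my_tag` is purely illustrative): a plain Tensor with no extra Python attributes has empty state, while a tensor carrying extra attributes does not, which is why dropping every field only matters for the latter.

```python
import torch

plain = torch.zeros(2)
# Expected to be falsy (empty) for a plain tensor, so the
# `type(self) is Tensor and not state` fast path applies.
print(torch._utils._get_obj_state(plain))

tagged = torch.zeros(2)
tagged.my_tag = "hello"  # stored in the tensor's __dict__
# Expected to contain the extra attribute, so the fast path does not apply.
print(torch._utils._get_obj_state(tagged))
```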
torch/_tensor.py (outdated)
```python
),
(
    isinstance(
        self, torch._subclasses.functional_tensor.FunctionalTensor
```
Should we just split functional + fake Tensor into another condition just above? I have to admit I can't read this condition anymore :D
torch/serialization.py (outdated)
```python
# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
_serialization_tls = threading.local()
```
Oh, I would have expected you to just do `torch._C._stash_obj_in_tls("_serialization_tls", _serialization_tls)` here and nothing else below?
Or, if we don't know how to do it, it's OK to remove it all and open an issue as a TODO for later.
Removed and opened issue here #134680
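For context, a loose sketch of the thread-local pattern the diff sets up (the field names come from the comments in the snippet above; the context-manager body is a simplified assumption, not the actual torch.serialization implementation):

```python
import threading

_serialization_tls = threading.local()
_serialization_tls.map_location = None
_serialization_tls.skip_data = False
_serialization_tls.materialize_fake_tensors = False

class skip_data:
    """Toggle the thread-local flags that the pickler consults during torch.save."""

    def __init__(self, materialize_fake_tensors: bool = False):
        self.materialize_fake_tensors = materialize_fake_tensors

    def __enter__(self):
        _serialization_tls.skip_data = True
        _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
        return self

    def __exit__(self, *exc):
        _serialization_tls.skip_data = False
        _serialization_tls.materialize_fake_tensors = False
```

Because the state lives in a `threading.local`, entering the context manager on one thread leaves concurrent `torch.save` calls on other threads unaffected.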
## Semantic

The semantic is:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor.

## Follow Ups

- [ ] `torch.load` semantic for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
The repro given no longer fails; see D62238610.
SGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add torch.serialization.skip_data context manager (#134504)
Pull Request resolved: pytorch#134504
Approved by: https://github.com/albanD
…4504)" This reverts commit 202600b. Reverted pytorch#134504 on behalf of https://github.com/mikaylagawarecki due to This is breaking Windows docs tests due to NamedTemporaryFile on Windows not working well ([comment](pytorch#134504 (comment)))
Add torch.serialization.skip_data context manager (#134504)
Pull Request resolved: pytorch#134504
Approved by: https://github.com/albanD
…4504)" This reverts commit 94db935. Reverted pytorch#134504 on behalf of https://github.com/kit1980 due to See D62082697 ([comment](pytorch#134504 (comment)))
Add torch.serialization.skip_data context manager (#134504)
Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: pytorch#134504
Approved by: https://github.com/albanD
## Semantic

The semantic is:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor.

## Follow Ups

- [ ] `torch.load` semantic for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Stack from ghstack (oldest at bottom):

Differential Revision: D62238610
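The usage examples from the PR description, consolidated here for reference (file names like `foo.pt` and `bla.pt` are illustrative; outputs elided):

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

# (1) Skip writing storage bytes for real tensors; space is still reserved in the checkpoint.
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))

# (2) Treat FakeTensors as "materialized": reserve space for their storage bytes
# and load them back as plain Tensors.
with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
```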