Add weights_only option to torch.load #86812

Closed
wants to merge 13 commits into from
Closed

Contributor

@malfet malfet commented Oct 12, 2022

This addresses the security issue in Python's default unpickler, which allows arbitrary code execution while unpickling.
It restricts the classes allowed to be unpickled to None, int, bool, str, float, list, tuple, dict/OrderedDict, torch.Size, torch.nn.Parameter, and the torch.Tensor and torch.Storage variants.

weights_only defaults to False, but a global override that forces safe-only loading is available via the TORCH_FORCE_WEIGHTS_ONLY_LOAD environment variable.

To some extent, addresses #52596
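
For readers unfamiliar with the approach, the allowlisting idea described above can be sketched with the standard library alone. This is a minimal illustration, not the code in this PR (which ships its own unpickler in torch/_weights_only_unpickler.py): it overrides `pickle.Unpickler.find_class`, and the names `ALLOWED_GLOBALS`, `AllowlistUnpickler`, `safe_loads`, and `Exploit` are hypothetical.

```python
import collections
import io
import pickle

# Hypothetical allowlist for illustration only. The PR's real list also
# covers torch tensor/storage classes and their rebuild functions.
ALLOWED_GLOBALS = {
    ("collections", "OrderedDict"): collections.OrderedDict,
}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses every global not present in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        try:
            return ALLOWED_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def safe_loads(data: bytes):
    return AllowlistUnpickler(io.BytesIO(data)).load()


class Exploit:
    """Object whose pickle asks the unpickler to call os.system."""

    def __reduce__(self):
        import os
        return (os.system, ("echo this should never run",))


# A checkpoint-like payload made of allowlisted types loads normally.
plain = pickle.dumps(collections.OrderedDict(weight=[1.0, 2.0], step=3))
restored = safe_loads(plain)

# The malicious payload is rejected before os.system is ever looked up.
try:
    safe_loads(pickle.dumps(Exploit()))
    blocked = False
except pickle.UnpicklingError:
    blocked = True
```

Built-in scalars and containers (int, float, str, list, tuple, dict) never reach `find_class` because pickle encodes them with dedicated opcodes, which is why only `OrderedDict` needs an entry in this sketch.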

@pytorch-bot

pytorch-bot bot commented Oct 12, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86812

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Failures, 3 Pending

As of commit 2a2912f:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet malfet force-pushed the malfet/safer-unpickler branch 3 times, most recently from e5c20ff to 7af2388 Compare October 14, 2022 22:52
@malfet malfet changed the title [WIP] Safer unpickler Add load_weights_only option to torch.load Oct 14, 2022
@malfet malfet marked this pull request as ready for review October 14, 2022 23:26
@malfet malfet requested review from ezyang and soumith October 14, 2022 23:32
@dzhulgakov dzhulgakov self-requested a review October 15, 2022 04:23
Collaborator

@dzhulgakov dzhulgakov left a comment

Nice!

Can you also add some tests please?

@McPatate McPatate left a comment

Kudos on reimplementing the unpickler yourselves, this looks very nice!

@@ -742,6 +746,15 @@ def load(
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')
    """
    UNSAFE_MESSAGE = (
        "*WARNING* Safe load failed, defaulting to standard Python pickle module, "
Contributor

It seems weird to have a mode which tries safe load, and then AUTOMATICALLY tries the unsafe load after. Like, what's the point? It's not any more secure, and it also gives you a false sense of security when there is no warning because the safe load succeeded.

I also thought it was strange to fall back to the unsafe load if the safe one failed. I suppose it's to avoid breaking changes, for backwards compatibility.

Contributor Author

I guess the question is whether we care about backward compatibility or not. I'm of two minds, to be frank, but I'm fine with flipping the logic so that the unsafe path requires load to be called with weights_only set to False.

Contributor

We do care about BC... but if you want to deprecate the old insecure API, you should raise the warning unconditionally.

Collaborator

It seems weird to have a mode which tries safe load, and then AUTOMATICALLY tries the unsafe load after.

@ezyang - note that this warning as written currently in the diff gets triggered only if weights_only=False. So it doesn't violate safety.

This behavior (as implemented) makes sense only as a temporary "one release" workaround to drive people to put an explicit weights_only=False in the places that need it, i.e. the places where this logic fails and generates a warning. Then in the follow-up release we'd switch the default to weights_only=True. Places that were handling only tensors anyway (presumably the majority) won't need to change their code, will never see the warning, and will become secure in the future.

Of course we can jump to weights_only=True right away but I'm a bit hesitant about breaking BC
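
The transitional behavior discussed in this thread could look roughly like the following. This is a sketch of the idea as described in the comments, written against the standard library's pickle rather than the PR's unpickler; `load_bytes` and `_NoGlobalsUnpickler` are hypothetical names, and the warning text is a placeholder.

```python
import collections
import io
import pickle
import warnings


class _NoGlobalsUnpickler(pickle.Unpickler):
    """Stand-in for a restricted unpickler: rejects every global lookup."""

    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"Unsupported global: {module}.{name}")


def load_bytes(data: bytes, weights_only: bool = False):
    """Try the restricted path first; fall back only when weights_only is False."""
    try:
        return _NoGlobalsUnpickler(io.BytesIO(data)).load()
    except pickle.UnpicklingError as err:
        if weights_only:
            raise  # strict mode never falls back to the unsafe unpickler
        warnings.warn(
            f"Safe load failed ({err}); falling back to the standard pickle module.",
            UserWarning,
            stacklevel=2,
        )
        return pickle.loads(data)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    # Plain containers need no globals, so the safe path succeeds silently.
    tensors_only = load_bytes(pickle.dumps({"w": [1.0, 2.0]}))
    warnings_after_safe = len(caught)

    # OrderedDict requires a global lookup, so the fallback runs and warns.
    fallback = load_bytes(pickle.dumps(collections.OrderedDict(a=1)))
    warnings_after_fallback = len(caught)

# With weights_only=True the same payload is rejected instead.
try:
    load_bytes(pickle.dumps(collections.OrderedDict(a=1)), weights_only=True)
    strict_rejected = False
except pickle.UnpicklingError:
    strict_rejected = True
```

Callers that only ever load plain containers never see the warning, which is the property the comment above relies on for flipping the default later.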

"torch.LongStorage": torch.FloatStorage,
"torch.nn.parameter.Parameter": torch.nn.Parameter,
"torch._utils._rebuild_parameter": torch._utils._rebuild_parameter,
"torch._utils._rebuild_tensor_v2": torch._utils._rebuild_tensor_v2,
Contributor

There is also an implicit invariant: each of these functions/constructors must be able to take arbitrary user input and "construct" it safely, without triggering arbitrary code execution. This is not obvious when implementing these functions, so it needs to be noted at each of those spots.
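
A small illustration of that invariant, using only the standard library: once a callable is allowlisted, the pickle stream chooses its arguments. The module name `sketch`, the function `rebuild_vector`, and its limits are hypothetical and merely stand in for the rebuild helpers listed above.

```python
import io
import pickle

MAX_LENGTH = 1_000  # hypothetical safety limit for this sketch


def rebuild_vector(length, fill):
    """Allowlisted constructor; its arguments come straight from the stream."""
    if isinstance(length, bool) or not isinstance(length, int):
        raise pickle.UnpicklingError("length must be an int")
    if not 0 <= length <= MAX_LENGTH:
        raise pickle.UnpicklingError(f"length {length} is out of range")
    if isinstance(fill, bool) or not isinstance(fill, (int, float)):
        raise pickle.UnpicklingError("fill must be a number")
    return [fill] * length


ALLOWED_GLOBALS = {("sketch", "rebuild_vector"): rebuild_vector}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        try:
            return ALLOWED_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def safe_loads(data: bytes):
    return AllowlistUnpickler(io.BytesIO(data)).load()


# Hand-written protocol 0 streams: GLOBAL, MARK, INT, FLOAT, TUPLE, REDUCE, STOP.
benign = b"csketch\nrebuild_vector\n(I3\nF0.5\ntR."
hostile = b"csketch\nrebuild_vector\n(I1000000000000\nF0.5\ntR."

vector = safe_loads(benign)

# Without the range check this call would try to allocate a huge list.
try:
    safe_loads(hostile)
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

The allowlist check passes for both streams, so the only thing standing between the hostile input and the allocation is the validation inside the allowlisted function itself.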

@malfet malfet changed the title Add load_weights_only option to torch.load Add weights_only option to torch.load Oct 18, 2022

Contributor Author

malfet commented Oct 20, 2022

Incorporating more feedback from the discussions:

  • Allowlist all torch._tensor_classes and torch._storage_classes
  • Add a global override of the behavior via the environment variable, but set it to false by default
  • Enable safe unpickling for dict-only use cases in unit tests
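
As a rough sketch of how such a global override can combine with the per-call argument: the helper `resolve_weights_only` is hypothetical, and the set of values treated as "true" is an assumption made for illustration, not something stated in this thread.

```python
import os

# Assumed spellings of "enabled"; the thread does not specify them.
_TRUTHY = {"1", "y", "yes", "true"}


def resolve_weights_only(weights_only: bool = False, environ=os.environ) -> bool:
    """Return True if either the caller or the environment asks for safe loading."""
    forced = environ.get("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").strip().lower()
    return bool(weights_only) or forced in _TRUTHY


# The environment can only tighten the behavior, never loosen it.
default_case = resolve_weights_only(False, environ={})
forced_case = resolve_weights_only(False, environ={"TORCH_FORCE_WEIGHTS_ONLY_LOAD": "1"})
explicit_case = resolve_weights_only(True, environ={"TORCH_FORCE_WEIGHTS_ONLY_LOAD": "0"})
```

Passing `environ` explicitly keeps the helper easy to test without mutating the real process environment.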

try:
    return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args)
except RuntimeError as e:
    raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
Contributor

Are we sure we definitely want to suppress the internal exception?

Contributor Author

All error messages raised by _weights_only_unpickler are unique, and the resulting error message is much cleaner this way.
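
For reference, the trade-off in this exchange comes down to `raise ... from None` versus `raise ... from e`. The sketch below uses a placeholder `UNSAFE_MESSAGE` and a hypothetical `_failing_restricted_load` to show what each form leaves on the exception object.

```python
import pickle

UNSAFE_MESSAGE = "Weights only load failed: "  # placeholder text for this sketch


def _failing_restricted_load():
    # Hypothetical stand-in for the restricted unpickler hitting a bad class.
    raise RuntimeError("Unsupported class os.system")


def load_suppressed():
    try:
        return _failing_restricted_load()
    except RuntimeError as e:
        # "from None" hides the RuntimeError from the printed traceback.
        raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None


def load_chained():
    try:
        return _failing_restricted_load()
    except RuntimeError as e:
        # "from e" keeps the RuntimeError as the explicit cause.
        raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from e


def capture(fn):
    """Run fn and return the UnpicklingError it raises, or None."""
    try:
        fn()
    except pickle.UnpicklingError as err:
        return err
    return None


suppressed = capture(load_suppressed)
chained = capture(load_chained)
```

Both variants carry the same message text. With `from None` the original exception is still reachable through `__context__`, but it is no longer shown in the traceback, which is the "cleaner" output referred to above.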

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 20, 2022
Contributor Author

malfet commented Oct 21, 2022

@pytorchbot merge -f "Mac failures are spurious"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions

Hey @malfet.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

malfet added a commit that referenced this pull request Oct 21, 2022
This addresses the security issue in Python's default `unpickler`, which allows arbitrary code execution while unpickling.
It restricts the classes allowed to be unpickled to `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict`, `torch.Size`, `torch.nn.Parameter`, and the `torch.Tensor` and `torch.Storage` variants.

`weights_only` defaults to `False`, but a global override that forces safe-only loading is available via the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.

To some extent, addresses #52596
Pull Request resolved: #86812
Approved by: https://github.com/ezyang

(cherry picked from commit 961ebca)
atalman pushed a commit that referenced this pull request Oct 21, 2022
* Tweak several test serialization to store models state_dict (#87143)

Namely, change:
- `test_meta_serialization`
- `test_serialization_2gb_file`
- `test_pathlike_serialization`
Pull Request resolved: #87143
Approved by: https://github.com/ezyang

(cherry picked from commit 4a533f1)

* Add `weights_only` option to `torch.load` (#86812)

This addresses the security issue in Python's default `unpickler`, which allows arbitrary code execution while unpickling.
It restricts the classes allowed to be unpickled to `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict`, `torch.Size`, `torch.nn.Parameter`, and the `torch.Tensor` and `torch.Storage` variants.

`weights_only` defaults to `False`, but a global override that forces safe-only loading is available via the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.

To some extent, addresses #52596
Pull Request resolved: #86812
Approved by: https://github.com/ezyang

(cherry picked from commit 961ebca)
for ts in torch._storage_classes:
    rc[f"{ts.__module__}.{ts.__name__}"] = ts
# Rebuild functions
for f in [
Collaborator

Is _rebuild_qtensor missing? How can we ensure this list stays up to date?
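
One possible way to approach the second question is a unit test that compares the allowlist against what a module actually defines. The sketch below is only an illustration of that idea: `fake_utils` is a stand-in namespace rather than the real torch module, and `missing_rebuild_functions` is a hypothetical helper.

```python
import types

# Stand-in for a module that defines rebuild helpers; contents are illustrative.
fake_utils = types.SimpleNamespace(
    _rebuild_tensor_v2=lambda *args: None,
    _rebuild_parameter=lambda *args: None,
    _rebuild_qtensor=lambda *args: None,
    unrelated_helper=lambda *args: None,
)


def missing_rebuild_functions(module, allowlisted_names):
    """List callables named _rebuild_* that the allowlist does not mention."""
    discovered = {
        name
        for name, value in vars(module).items()
        if name.startswith("_rebuild_") and callable(value)
    }
    return sorted(discovered - set(allowlisted_names))


missing = missing_rebuild_functions(
    fake_utils, {"_rebuild_tensor_v2", "_rebuild_parameter"}
)
```

A check like this only flags candidates for review. Whether a newly discovered function is safe to allowlist still has to be decided by a person, since each entry must tolerate untrusted arguments.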

sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
@malfet malfet deleted the malfet/safer-unpickler branch November 4, 2022 04:30
pytorchmergebot pushed a commit that referenced this pull request Apr 11, 2023
…98479)

This adds a `weights_only` option to torch.hub.load_state_dict_from_url which is helpful for loading pretrained models from potentially untrusted sources.

Ex: https://github.com/d4l3k/torchdrive/blob/main/torchdrive/models/simple_bev.py#L618-L621

See #86812 for more info on weights_only

Test plan:

```
pytest test/test_hub.py
```

Pull Request resolved: #98479
Approved by: https://github.com/NicolasHug
ZainRizvi pushed a commit that referenced this pull request Apr 19, 2023
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged
5 participants