-
Notifications
You must be signed in to change notification settings - Fork 21.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add weights_only
option to torch.load
#86812
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86812
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 Failures, 3 PendingAs of commit 2a2912f: The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
e5c20ff
to
7af2388
Compare
load_weights_only
option to torch.load
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Can you also add some tests please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kudos on reimplementing the unpickler yourselves, this look very nice !
torch/serialization.py
Outdated
@@ -742,6 +746,15 @@ def load( | |||
# Load a module with 'ascii' encoding for unpickling | |||
>>> torch.load('module.pt', encoding='ascii') | |||
""" | |||
UNSAFE_MESSAGE = ( | |||
"*WARNING* Safe load failed, defaulting to standard Python pickle module, " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's seems weird to have a mode which tries safe load, and then AUTOMATICALLY tries the unsafe load after. Like, what's the point? It's not any more secure, and it also gives you a false sense of security when there is no warning because the safe load succeeded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also thought it was strange to go with the unsafe load if safe failed, I suppose it's to avoid causing breaking changes for backwards compat.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the question is, do we care about backward compatibility or not. I'm of two minds to be frank, but fine with flipping the logic to load
should be called with weights_only
set to False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do care about BC... but if you want to deprecate the old insecure API, you should raise the warning unconditionally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's seems weird to have a mode which tries safe load, and then AUTOMATICALLY tries the unsafe load after.
@ezyang - note that this warning as written currently in the diff gets triggered only if weights_only=False
. So it doesn't violate safety.
This behavior (as implemented) makes sense only as a temporary "one release" workaround to drive people to put explicit weights_only=False
in the needed places. I.e. those would be places where this logic fails and generates a warning. Then in the follow up release we'd switch default to be weights_only=True
. Places that were handling tensors only anyway (presumably the majority) won't need to change the code, will never see the warning and will become secure in the future.
Of course we can jump to weights_only=True
right away but I'm a bit hesitant about breaking BC
torch/_weights_only_unpickler.py
Outdated
"torch.LongStorage": torch.FloatStorage, | ||
"torch.nn.parameter.Parameter": torch.nn.Parameter, | ||
"torch._utils._rebuild_parameter": torch._utils._rebuild_parameter, | ||
"torch._utils._rebuild_tensor_v2": torch._utils._rebuild_tensor_v2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also an implicit invariant, which is that each of these functions/constructors must be able to take arbitrary user input and "construct" it safely, without triggering arbitrary code execution. This is not obvious when implementing these functions and needs to be notated at all the spots.
43f7b75
to
a7ceca3
Compare
load_weights_only
option to torch.load
weights_only
option to torch.load
torch/serialization.py
Outdated
@@ -742,6 +746,15 @@ def load( | |||
# Load a module with 'ascii' encoding for unpickling | |||
>>> torch.load('module.pt', encoding='ascii') | |||
""" | |||
UNSAFE_MESSAGE = ( | |||
"*WARNING* Safe load failed, defaulting to standard Python pickle module, " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's seems weird to have a mode which tries safe load, and then AUTOMATICALLY tries the unsafe load after.
@ezyang - note that this warning as written currently in the diff gets triggered only if weights_only=False
. So it doesn't violate safety.
This behavior (as implemented) makes sense only as a temporary "one release" workaround to drive people to put explicit weights_only=False
in the needed places. I.e. those would be places where this logic fails and generates a warning. Then in the follow up release we'd switch default to be weights_only=True
. Places that were handling tensors only anyway (presumably the majority) won't need to change the code, will never see the warning and will become secure in the future.
Of course we can jump to weights_only=True
right away but I'm a bit hesitant about breaking BC
To validate that loading pickled object raises an error
6d90cd5
to
77fb38d
Compare
Incorporating more feedback from the discussions:
|
try: | ||
return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args) | ||
except RuntimeError as e: | ||
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we sure we definitely want to suppress the internal exception?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All error messages raised by _weights_only_unpickler
are unique, but on the other hand error message is much cleaner than way.
@pytorchbot merge -f "Mac failures are spurious" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Hey @malfet. |
This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling. Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as `torch.Tensor` and `torch.Storage` variants. Defaults `weights_only` is set to `False`, but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable. To some extent, addresses #52596 Pull Request resolved: #86812 Approved by: https://github.com/ezyang (cherry picked from commit 961ebca)
* Tweak several test serialization to store models state_dict (#87143) Namely, change: - `test_meta_serialization` - `test_serialization_2gb_file` - `test_pathlike_serialization` Pull Request resolved: #87143 Approved by: https://github.com/ezyang (cherry picked from commit 4a533f1) * Add `weights_only` option to `torch.load` (#86812) This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling. Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as `torch.Tensor` and `torch.Storage` variants. Defaults `weights_only` is set to `False`, but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable. To some extent, addresses #52596 Pull Request resolved: #86812 Approved by: https://github.com/ezyang (cherry picked from commit 961ebca)
for ts in torch._storage_classes: | ||
rc[f"{ts.__module__}.{ts.__name__}"] = ts | ||
# Rebuild functions | ||
for f in [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is the _rebuild_qtensor missing? How can we ensure this list is up to date?
This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling. Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as `torch.Tensor` and `torch.Storage` variants. Defaults `weights_only` is set to `False`, but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable. To some extent, addresses pytorch#52596 Pull Request resolved: pytorch#86812 Approved by: https://github.com/ezyang
…98479) This adds a `weights_only` option to torch.hub.load_state_dict_from_url which is helpful for loading pretrained models from potentially untrusted sources. Ex: https://github.com/d4l3k/torchdrive/blob/main/torchdrive/models/simple_bev.py#L618-L621 See #86812 for more info on weights_only Test plan: ``` pytest test/test_hub.py ``` Pull Request resolved: #98479 Approved by: https://github.com/NicolasHug
…98479) This adds a `weights_only` option to torch.hub.load_state_dict_from_url which is helpful for loading pretrained models from potentially untrusted sources. Ex: https://github.com/d4l3k/torchdrive/blob/main/torchdrive/models/simple_bev.py#L618-L621 See #86812 for more info on weights_only Test plan: ``` pytest test/test_hub.py ``` Pull Request resolved: #98479 Approved by: https://github.com/NicolasHug
This addresses the security issue in default Python's
unpickler
that allows arbitrary code execution while unpickling.Restrict classes allowed to be unpicked to in
None
,int
,bool
,str
,float
,list
,tuple
,dict
/OrderedDict
as well astorch.Size
,torch.nn.Param
as well astorch.Tensor
andtorch.Storage
variants.Defaults
weights_only
is set toFalse
, but allows global override to safe only load viaTORCH_FORCE_WEIGHTS_ONLY_LOAD
environment variable.To some extent, addresses #52596