Allow tensor subclasses and add `torch.serialization.add_safe_globals` that allows users to allowlist classes for `weights_only` load #124331
Conversation
Marking as draft after discussion with Alban -- while we will likely still need a way for users to allowlist specific classes, want to also look into whether there is a generic way to allowlist tensor subclasses (with appropriate guardrails for security)
torch/_weights_only_unpickler.py (outdated)

```python
else:
    raise RuntimeError(f"Unsupported class {full_path}")
class_type = getattr(modules[module], name)
```
@albanD I think there might be a fundamental problem here that prevents us from allowlisting all tensor subclasses. The `GLOBAL bla` instruction in the pickle file gives us `bla` as a string; in order to check `issubclass(bla, torch.Tensor)`, we need to convert that string to the actual type object. To my knowledge, there is no safe way to do this, and I suspect this line here might not be what we consider secure. wdyt? Is there a way we can constrain this such that it is still secure?
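A hypothetical illustration of the concern: even a plain attribute lookup on an already-imported module can execute arbitrary code, e.g. via a module-level `__getattr__` (PEP 562). The module below is invented for illustration:

```python
# evil_module.py (hypothetical): importing this module is harmless, but
# getattr(evil_module, "any_missing_name") will execute this function.
def __getattr__(name):
    print("arbitrary code runs during attribute lookup")  # could be anything
    raise AttributeError(name)
```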
From offline discussion, a good middle ground is to put the responsibility to import the module on the user.
Sounds good!
torch/_weights_only_unpickler.py (outdated)

```python
f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was "
f"not in the pre-defined list of allowed globals that are considered safe by the "
"weights_only unpickler for rebuilding state_dicts. This is the expected behavior if "
f"`{full_path}` is a user-defined tensor subclass not defined in the `torch` package. "
```
This code will be called for any class that is not allowed and whose parent module is not already imported, right?
You're right! I edited a bit to reflect this case; is the new phrasing clear?
torch/_weights_only_unpickler.py (outdated)

```python
f"If this is the case, we expect `{module}` to be present in `sys.modules` (i.e. it "
"must be imported in the current environment), but this was not the case. "
f"If you intend to unpickle a `{full_path}` please import `{name}` from `{module}`. "
f"Note that having this imported will *only* allow the type `{full_path}` to be passed "
```
We should have a docs page that goes over this, and point to it at the end. This is OK temporarily while we're getting that landing page up.
torch/_weights_only_unpickler.py (outdated)

```python
arg_is_subclass_type = [
    i
    for i, arg in enumerate(args)
    if isinstance(arg, type) and arg in self.tensor_subclasses_found
]
```
Suggested change:

```diff
-    if isinstance(arg, type) and arg in self.tensor_subclasses_found
+    if isinstance(arg, type) and issubclass(arg, torch.Tensor)
```

Why not this?
This is a valid point, but to be extra safe, I think we should actually never let the subclass type be on the unpickler's stack, so in the most recent push I changed the strategy here:

(1) `GLOBAL` finds the subclass; if it meets the criteria, it pushes the str `"module.attr"` onto the unpickler's stack and adds `self.tensor_subclasses_found["module.attr"] = module.attr`
(2) when executing `REDUCE` for `_rebuild_from_type_v2`, we properly translate the str back to the actual type

Does that sound reasonable to you?
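A minimal sketch of that strategy, with hypothetical helper names (this is not the actual unpickler code):

```python
# Side table: the stack only ever holds strings for subclass types; the
# real (already-vetted) type lives here until REDUCE needs it.
tensor_subclasses_found: dict[str, type] = {}

def on_global(module: str, name: str, subclass_type: type, stack: list) -> None:
    # (1) GLOBAL: record the vetted type and push only the string.
    full_path = f"{module}.{name}"
    tensor_subclasses_found[full_path] = subclass_type
    stack.append(full_path)

def on_reduce_rebuild_from_type_v2(args: tuple) -> tuple:
    # (2) REDUCE for _rebuild_from_type_v2: swap recorded strings back to
    # their real types before invoking the rebuild function.
    return tuple(
        tensor_subclasses_found.get(a, a) if isinstance(a, str) else a
        for a in args
    )
```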
Sounds good.
```python
Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs
'''
# Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment
global TwoTensor
```
Huh why?
This is weird to me as well, but it seems to be necessary because `TwoTensor` is imported again later in the test. A minimized version still throws the error:

```python
from torch.testing._internal.two_tensor import TwoTensor

def test_safe_globals_for_weights_only(self):
    t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
    from torch.testing._internal.two_tensor import TwoTensor
```

Here `t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))` raises `UnboundLocalError: local variable 'TwoTensor' referenced before assignment`: the import inside the function body makes `TwoTensor` a local name for the entire function, so the earlier reference no longer resolves to the module-level import.
Sounds good!
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…127808) Remove the logic from #124331 that auto-detects and allows subclasses that did not override certain methods from the weights_only unpickler, for the 2.4 release. Subclasses should be loadable using `torch.serialization.add_safe_globals`. Pull Request resolved: #127808. Approved by: https://github.com/malfet
#### Conditions for allowlisting tensor subclasses

We allow tensor subclass types that

(1) Do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not have a definition of `__getattr__`, `__get__` or `__set__`, so we check that these are `None`)
(2) Use the generic `tp_alloc`
(3) Are in a module that *has been imported by the user*

to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict. The strings will be converted to the classes as appropriate when executing `REDUCE` with `_rebuild_from_type_v2`.

*Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function, as this method claims to have no code execution.*

The rationale for the 3 conditions above is as follows:

The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, to `Tensor.__setstate__`, and to `as_subclass`, as well as the call to `_set_obj_state`, which calls `setattr`):

https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71
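For reference, a paraphrased sketch of that rebuild function, reconstructed from the description above (see the permalink for the exact source):

```python
import torch
from torch import Tensor

# Paraphrased sketch of torch._tensor._rebuild_from_type_v2, not the
# verbatim source; see the permalink above for the exact definition.
def _rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)  # e.g. torch._utils._rebuild_wrapper_subclass
    if type(ret) is not new_type:
        ret = ret.as_subclass(new_type)  # THPVariable_NewWithVar -> tp_alloc
    # Only call __setstate__ if the subclass overrides Tensor's default;
    # otherwise restore attributes via _set_obj_state (which uses setattr).
    if getattr(new_type, "__setstate__", Tensor.__setstate__) is not Tensor.__setstate__:
        ret.__setstate__(state)
    else:
        ret = torch._utils._set_obj_state(ret, state)
    return ret
```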
`as_subclass` is implemented with a call to `THPVariable_NewWithVar` that will eventually call `tp_alloc` here:

https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053
The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`.

**Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling.**
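A small demo of that behavior, under the assumption that `as_subclass` exercises the same `tp_alloc` path described above (the `Loud` class is invented for illustration):

```python
import torch

class Loud(torch.Tensor):
    def __new__(cls, *args, **kwargs):
        print("__new__ called")
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        print("__init__ called")
        super().__init__()

t = torch.randn(2)
s = t.as_subclass(Loud)  # prints nothing: tp_alloc only, no __new__/__init__
print(type(s))           # <class '__main__.Loud'>
```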
### How do we check something is a tensor subclass / constraints around imports

In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules, but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
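A minimal sketch of that check, assuming a hypothetical helper name (`_maybe_get_tensor_subclass` is illustrative, not the unpickler's actual function):

```python
import inspect
import sys

import torch

def _maybe_get_tensor_subclass(module: str, name: str):
    # Never import on the pickle's behalf: only proceed if the user has
    # already imported the module.
    if module not in sys.modules:
        return None
    # getattr_static avoids running module __getattr__ or descriptor code.
    cls = inspect.getattr_static(sys.modules[module], name)
    if isinstance(cls, type) and issubclass(cls, torch.Tensor):
        return cls
    return None
```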
This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`).

### API for allow listing
This PR also added `torch.serialization.{add/get/clear}_safe_globals`, which enables users to allowlist globals they have deemed safe and to manipulate this list (for example, they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe).
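A minimal usage sketch, mirroring the `TwoTensor` test mentioned above (the checkpoint file name is illustrative):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
torch.save(t, "two_tensor.pt")

# Allowlist the subclass so the weights_only unpickler will accept it.
torch.serialization.add_safe_globals([TwoTensor])
loaded = torch.load("two_tensor.pt", weights_only=True)

print(torch.serialization.get_safe_globals())  # inspect the allowlist
torch.serialization.clear_safe_globals()       # reset it
```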
Next steps:

- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)

Stack from ghstack (oldest at bottom):

- Allow tensor subclasses and add `torch.serialization.add_safe_globals` that allows users to allowlist classes for `weights_only` load #124331