
Allow tensor subclasses and add torch.serialization.add_safe_globals that allows users to allowlist classes for weights_only load #124331

Conversation

@mikaylagawarecki (Contributor) commented Apr 17, 2024

Conditions for allowlisting tensor subclasses

We allow tensor subclass types that
(1) do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` has no definition of `__getattr__`, `__get__` or `__set__`, so we check that these are `None`),
(2) use the generic `tp_alloc`, and
(3) are in a module that has been imported by the user
to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict.

The strings will be converted back to the classes as appropriate when executing `REDUCE` with `_rebuild_from_type_v2`.

*Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function, as this method is documented to perform the lookup without code execution.
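Condition (1) can be illustrated with a small plain-Python sketch. This is not torch's actual code: `Base`, `Plain` and `Sneaky` are made-up stand-ins for `torch.Tensor` and its subclasses, and `overrides_checked_dunders` is a hypothetical helper.

```python
import inspect

class Base:
    """Stand-in for torch.Tensor: defines __setstate__, but (like Tensor)
    has no definition of __getattr__, __get__ or __set__."""
    def __setstate__(self, state):
        self.__dict__.update(state)

# Dunders a subclass may not override relative to Base (condition 1).
CHECKED = ("__setstate__", "__getattr__", "__setattr__",
           "__get__", "__set__", "__getattribute__")

def overrides_checked_dunders(cls, base=Base):
    """True if cls resolves any checked dunder differently from base.
    getattr_static performs the lookup without triggering descriptors
    or __getattr__, i.e. without running user code."""
    return any(
        inspect.getattr_static(cls, name, None)
        is not inspect.getattr_static(base, name, None)
        for name in CHECKED
    )

class Plain(Base):            # overrides nothing checked -> allowed
    pass

class Sneaky(Base):           # overrides __setattr__ -> rejected
    def __setattr__(self, key, value):
        super().__setattr__(key, value)

print(overrides_checked_dunders(Plain))   # False
print(overrides_checked_dunders(Sneaky))  # True
```

The real check also requires the generic `tp_alloc` (condition 2), which is a C-level property and cannot be sketched in pure Python.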

The rationale for the 3 conditions above is as follows:

The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, defined as follows (note the call to `getattr`, to `Tensor.__setstate__` and to `as_subclass`, as well as the call to `_set_obj_state`, which calls `setattr`):

pytorch/torch/_tensor.py

Lines 57 to 71 in 4e66aaa

```python
def _rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)
    if type(ret) is not new_type:
        ret = ret.as_subclass(new_type)
    # Tensor does define __setstate__ even though it doesn't define
    # __getstate__. So only use __setstate__ if it is NOT the one defined
    # on Tensor
    if (
        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
        is not Tensor.__setstate__
    ):
        ret.__setstate__(state)
    else:
        ret = torch._utils._set_obj_state(ret, state)
    return ret
```

`as_subclass` is implemented with a call to `THPVariable_NewWithVar`, which will eventually call `tp_alloc` here:

PyObject* obj = type->tp_alloc(type, 0);

The func arg to _rebuild_from_type_v2 for wrapper subclasses is Tensor.rebuild_wrapper_subclass, which will similarly call into THPVariable_NewWithVar and hit the above tp_alloc

Note that we do not call tp_init or tp_new (i.e. cls.__init__ or cls.__new__) when unpickling
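That reconstruction bypasses `__init__` can be observed even with stock `pickle`: at protocol 2+ the default reduce machinery allocates via `__new__` and restores state directly, never re-running `__init__` (torch's `tp_alloc` path goes further and also skips `__new__`). A minimal, torch-free illustration:

```python
import pickle

class Counter:
    init_calls = 0

    def __init__(self, x):
        Counter.init_calls += 1
        self.x = x

c = Counter(41)
blob = pickle.dumps(c)     # default protocol emits NEWOBJ, not an __init__ call
c2 = pickle.loads(blob)

print(Counter.init_calls)  # 1 -- __init__ ran only for the original object
print(c2.x)                # 41 -- state restored via __dict__, not __init__
```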

How do we check something is a tensor subclass/constraints around imports

In order to check whether `bla` is a tensor subclass when handling the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We do not arbitrarily import modules, but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
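A plain-Python sketch of this constrained lookup (hypothetical helper names; `Base` stands in for `torch.Tensor`, and `user_subclasses` is a synthetic module simulating one the user imported):

```python
import inspect
import sys
import types

class Base:
    """Stand-in for torch.Tensor (not torch code)."""

def resolve_if_imported_subclass(module, name, base=Base):
    """Resolve `module.name` to a class only if the module was already
    imported by the user; never import anything on the pickle's behalf."""
    if module not in sys.modules:
        return None  # refuse: loading a checkpoint must not trigger imports
    # getattr_static fetches the attribute without running user code
    obj = inspect.getattr_static(sys.modules[module], name, None)
    if isinstance(obj, type) and issubclass(obj, base):
        return obj
    return None

class Sub(Base):
    pass

# Simulate a module the user has imported before loading the checkpoint:
user_mod = types.ModuleType("user_subclasses")
user_mod.Sub = Sub
sys.modules["user_subclasses"] = user_mod

print(resolve_if_imported_subclass("user_subclasses", "Sub") is Sub)  # True
print(resolve_if_imported_subclass("never_imported_xyz", "Sub"))      # None
print(resolve_if_imported_subclass("sys", "path"))                    # None
```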

This PR also allowlisted torch._utils._rebuild_wrapper_subclass and torch.device (used by _rebuild_wrapper_subclass)

API for allowlisting

This PR also added `torch.serialization.{add/get/clear}_safe_globals`, which enable users to allowlist globals they have deemed safe and to manipulate this list (for example, they could allowlist a tensor subclass with a custom `__setstate__` after checking that it is safe).
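A toy, torch-free sketch of that API's shape (this only mirrors the `{add/get/clear}_safe_globals` surface described above; it is not torch's implementation, and `MyTensorSubclass`/`global_is_allowed` are invented for illustration):

```python
# Module-level allowlist, mirroring the torch.serialization surface.
_safe_globals = []

def add_safe_globals(safe_globals):
    """Allowlist classes/functions the user has vetted for weights_only load."""
    _safe_globals.extend(safe_globals)

def get_safe_globals():
    return list(_safe_globals)

def clear_safe_globals():
    _safe_globals.clear()

class MyTensorSubclass:
    """Stand-in for a user-defined subclass with a vetted custom __setstate__."""

# The user vets the class, then allowlists it before loading:
add_safe_globals([MyTensorSubclass])
print(MyTensorSubclass in get_safe_globals())   # True

# The unpickler's GLOBAL handler would consult the same list:
def global_is_allowed(obj):
    return obj in _safe_globals

print(global_is_allowed(MyTensorSubclass))      # True
clear_safe_globals()
print(global_is_allowed(MyTensorSubclass))      # False
```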

Next steps:

  • Add testing and allowlist required classes for all in-core tensor subclasses (e.g. DTensor, FakeTensor etc.)

Stack from ghstack (oldest at bottom):

@pytorch-bot commented Apr 17, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124331

✅ No Failures as of commit e91a1de with merge base 8619fe6:
💚 Looks good so far! There are no failures yet. 💚

mikaylagawarecki added a commit that referenced this pull request Apr 17, 2024
…bclasses

ghstack-source-id: 01d0789c216764e8105c6caf5fcc1b775db35af3
Pull Request resolved: #124331
@mikaylagawarecki changed the title from "Add API for users to allowlist classes for weights_only for Tensor Subclasses" to "Add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only" Apr 18, 2024
@mikaylagawarecki changed the title from "Add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only" to "Add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only load" Apr 18, 2024
@mikaylagawarecki marked this pull request as ready for review April 18, 2024 17:56
@mikaylagawarecki added the `release notes: python_frontend` and `topic: new features` labels Apr 18, 2024
…ers to allowlist classes for `weights_only` load"

Allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`) and added `torch.serialization.mark_safe_globals`, which enables users to allowlist classes they have deemed safe, plus a simple test with `TwoTensor`.

This API is primarily intended to be used for checkpoints where user-defined Tensor Subclasses have been serialized.

Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)
mikaylagawarecki added a commit that referenced this pull request Apr 18, 2024
…bclasses

ghstack-source-id: 5dfbbad41105e5f9d4e89f2e227e46ecf3b7b732
Pull Request resolved: #124331
@mikaylagawarecki marked this pull request as draft April 18, 2024 20:28
@mikaylagawarecki (Contributor, Author) left a comment:

Marking as draft after discussion with Alban -- while we will likely still need a way for users to allowlist specific classes, want to also look into whether there is a generic way to allowlist tensor subclasses (with appropriate guardrails for security)

else:
raise RuntimeError(f"Unsupported class {full_path}")
class_type = getattr(modules[module], name)
@mikaylagawarecki commented Apr 22, 2024:

@albanD I think there might be a fundamental problem here that prevents us from allowlisting all tensor subclasses.

The `GLOBAL` instruction in the pickle file carries `bla` as a string; in order to check `issubclass(bla, torch.Tensor)`, we need to convert the string `bla` to the actual type object. To my knowledge, there is no safe way to do this, and I suspect this line here might not be what we consider secure. wdyt?

Is there a way where we can constrain this such that it is still secure?

@mikaylagawarecki replied:

From offline discussion, a good middle ground is to put the responsibility to import the module on the user

mikaylagawarecki added a commit that referenced this pull request Apr 22, 2024
…bclasses

ghstack-source-id: 467f6c37d45a486ad9f5a17faa0bdf67f990ad76
Pull Request resolved: #124331
@mikaylagawarecki changed the title from "Add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only load" to "Allow tensor subclasses and add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only load" Apr 23, 2024
@mikaylagawarecki marked this pull request as ready for review April 23, 2024 18:38
…safe_globals` that allows users to allowlist classes for `weights_only` load"

#### Conditions for allowlisting tensor subclasses
We allowlist tensor subclasses that 
(1) Do not override `__setstate__`
(2) Use the generic `tp_alloc`

The rationale for these two conditions is as follows:

The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `Tensor.__setstate__` and the call to `as_subclass`)

https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71

`as_subclass` is implemented with a call to `THPVariable_NewWithVar`

that will eventually call `tp_alloc` here
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053


The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`

**Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling**


### How do we check something is a tensor subclass/constraints around imports

In order to check whether `bla` is a tensor subclass when handling the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules, but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).

This PR also allowlisted  `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`) 

### API for allowlisting
This PR also added `torch.serialization.mark_safe_globals`, which enables users to allowlist globals they have deemed safe (for example, they could allowlist a tensor subclass with a custom `__setstate__` after checking that it is safe).

Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)
mikaylagawarecki added a commit that referenced this pull request May 3, 2024
…bclasses

ghstack-source-id: 526d31626de8325fa811562346b6990c1b57e8b9
Pull Request resolved: #124331
@mikaylagawarecki changed the title from "Allow tensor subclasses and add torch.serialization.mark_safe_globals that allows users to allowlist classes for weights_only load" to "Allow tensor subclasses and add torch.serialization.add_safe_globals that allows users to allowlist classes for weights_only load" May 13, 2024
mikaylagawarecki added a commit that referenced this pull request May 13, 2024
…bclasses

ghstack-source-id: f93fb3d7cdc5a40fb61581be41a490fc229f8d0f
Pull Request resolved: #124331
@albanD (Collaborator) left a comment:

Sounds good!

f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was "
f"not in the pre-defined list of allowed globals that are considered safe by the "
"weights_only unpickler for rebuilding state_dicts. This is the expected behavior if "
f"`{full_path}` is a user-defined tensor subclass not defined in the `torch` package. "
@albanD commented:

This code will be called for any class that is not allowed for which the parent module is not already imported right?

@mikaylagawarecki replied May 14, 2024:

You're right! I edited it a bit to reflect this case; is the new phrasing clearer?

f"If this is the case, we expect `{module}` to be present in `sys.modules` (i.e. it "
"must be imported in the current environment), but this was not the case. "
f"If you intend to unpickle a `{full_path}` please import `{name}` from `{module}`. "
f"Note that having this imported will *only* allow the type `{full_path}` to be passed "
@albanD commented:

We should have a docs page that goes over this and point to it at the end. This is OK temporarily while we're getting that landing page.

arg_is_subclass_type = [
    i
    for i, arg in enumerate(args)
    if isinstance(arg, type) and arg in self.tensor_subclasses_found
@albanD commented:

Suggested change:

```diff
-    if isinstance(arg, type) and arg in self.tensor_subclasses_found
+    if isinstance(arg, type) and issubclass(arg, torch.Tensor)
```

Why not this?

@mikaylagawarecki replied May 14, 2024:

This is a valid point, but to be extra safe, I think that we should actually never let the subclass type be on the unpickler's stack so in the most recent push I changed the strategy here to
(1) GLOBAL finds the subclass, if it meets the criteria, it pushes the str "module.attr" onto the unpickler's stack and adds self.tensor_subclasses_found["module.attr"] = module.attr
(2) when executing REDUCE for rebuild_from_type_v2, we properly translate the str back to the actual type

Does that sound reasonable to you?
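A minimal sketch of that two-step strategy (toy names, not the actual unpickler code):

```python
class ToyUnpickler:
    """Sketch: GLOBAL pushes only a string; the real type lives in a side
    table and is swapped back in when REDUCE rebuilds the object, so raw
    subclass types never sit on the unpickler's stack."""

    def __init__(self):
        self.stack = []
        self.tensor_subclasses_found = {}

    def on_global(self, module, attr, cls):
        key = f"{module}.{attr}"
        self.tensor_subclasses_found[key] = cls   # (1) remember the type...
        self.stack.append(key)                    # ...but push only the string

    def resolve_for_rebuild(self):
        key = self.stack.pop()
        return self.tensor_subclasses_found[key]  # (2) translate back at REDUCE

class FakeSubclass:
    pass

u = ToyUnpickler()
u.on_global("mymod", "FakeSubclass", FakeSubclass)
print(u.stack)                                   # ['mymod.FakeSubclass']
print(u.resolve_for_rebuild() is FakeSubclass)   # True
```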

@albanD replied:

Sounds good.

Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs
'''
# Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment
global TwoTensor
@albanD commented:

Huh why?

@mikaylagawarecki replied May 14, 2024:

This is weird to me as well, but it seems to be necessary because `TwoTensor` is imported again later in the test. Minimized, this will still throw the error:

```python
from torch.testing._internal.two_tensor import TwoTensor

def test_safe_globals_for_weights_only(self):
    t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
    from torch.testing._internal.two_tensor import TwoTensor
```

raises

```
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
UnboundLocalError: local variable 'TwoTensor' referenced before assignment
```
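This is standard Python scoping, reproducible without torch: any assignment (including an `import`) anywhere in a function body makes the name local to the whole function, so using the name before the local import raises `UnboundLocalError`; declaring the name `global` (as the test does with `global TwoTensor`) sidesteps this. A torch-free reproduction, with `OrderedDict` standing in for `TwoTensor`:

```python
from collections import OrderedDict  # module-level binding

def broken():
    d = OrderedDict()                # local name: shadowed by the import below
    from collections import OrderedDict
    return d

def fixed():
    global OrderedDict               # same workaround as `global TwoTensor`
    d = OrderedDict()
    from collections import OrderedDict  # now rebinds the *global* name
    return d

error = None
try:
    broken()
except UnboundLocalError as e:
    error = e

print(type(error).__name__)    # UnboundLocalError
print(type(fixed()).__name__)  # OrderedDict
```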

mikaylagawarecki added a commit that referenced this pull request May 15, 2024
…bclasses

ghstack-source-id: 6f67cc122c532bf6e52240ba1f009a0377ca9a09
Pull Request resolved: #124331
@albanD (Collaborator) left a comment:

Sounds good!

mikaylagawarecki added a commit that referenced this pull request May 17, 2024
…bclasses

ghstack-source-id: fbd63e6558d356424759f71513ba394c05e34eaa
Pull Request resolved: #124331
@mikaylagawarecki commented:

@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label May 17, 2024
@pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
…` that allows users to allowlist classes for `weights_only` load (pytorch#124331)

Pull Request resolved: pytorch#124331
Approved by: https://github.com/albanD
pytorchmergebot pushed a commit that referenced this pull request Jun 5, 2024
…127808)

Removes the logic from #124331 that auto-detected and allowed subclasses not overriding certain methods in the weights_only unpickler, for the 2.4 release.

Subclasses should be loadable using `torch.serialization.add_safe_globals`

Pull Request resolved: #127808
Approved by: https://github.com/malfet
@github-actions bot deleted the gh/mikaylagawarecki/196/head branch June 17, 2024 01:58
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: python_frontend, topic: new features