repeated warning UserWarning: TypedStorage is deprecated #97207

Closed
hchau630 opened this issue Mar 21, 2023 · 7 comments
Assignees
Labels
module: python frontend For issues relating to PyTorch's Python frontend triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@hchau630

hchau630 commented Mar 21, 2023

🐛 Describe the bug

I noticed that the warning UserWarning: TypedStorage is deprecated, which results from calling Tensor.storage(), is issued repeatedly, even though the documentation for torch.set_warn_always says that by default PyTorch warnings should only be issued once per process. Modifying my code to call untyped_storage() instead of storage() would get rid of the excessive warnings, but I would prefer not to, since I want to keep my code compatible with pre-2.0.0 versions for now and untyped_storage() is only available in >=2.0.0.

Example:

import torch

torch.tensor([1,2,3]).storage()
torch.tensor([1,2,3]).storage()
torch.tensor([1,2,3]).storage()

Output:

/path_to_file/test_warning.py:3: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  torch.tensor([1,2,3]).storage()
/path_to_file/test_warning.py:4: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  torch.tensor([1,2,3]).storage()
/path_to_file/test_warning.py:5: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  torch.tensor([1,2,3]).storage()
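
For the compatibility concern above, one possible approach is a small wrapper that prefers untyped_storage() when the tensor has it and falls back to storage() otherwise. This is only a sketch: get_storage is a hypothetical helper name, not a PyTorch API, and the two methods do not return equivalent objects (an UntypedStorage is byte-based, so code that indexes or measures the result would need to account for that).

```python
def get_storage(tensor):
    """Hypothetical helper: pick the storage accessor available on this tensor.

    Prefers tensor.untyped_storage() (PyTorch >= 2.0.0) and falls back to
    tensor.storage() on older versions. The returned objects differ in type,
    so callers must not assume element-wise indexing works the same way.
    """
    untyped = getattr(tensor, "untyped_storage", None)
    if callable(untyped):
        return untyped()
    return tensor.storage()
```

On 2.0.0 this avoids the storage() call, and therefore the deprecation warning, without raising AttributeError on older versions.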

Versions

Collecting environment information...
PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.2.1 (arm64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.28)
CMake version: version 3.26.0
Libc version: N/A

Python version: 3.9.16 (main, Mar  1 2023, 12:19:04)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==2.0.0
[conda] numpy                     1.24.2                   pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi

cc @albanD

albanD added the "triaged" and "module: python frontend" labels on Mar 21, 2023
@albanD
Collaborator

albanD commented Mar 21, 2023

cc @kurtamohler could you fix that?

As a temporary workaround, you can silence this warning with a regular Python warning filter: python -W ignore::UserWarning: foo.py

@kurtamohler kurtamohler self-assigned this Mar 22, 2023
@kurtamohler
Collaborator

The complication is that the warning message contains the line of code that triggered the warning, with warnings.warn(..., stacklevel=...):

pytorch/torch/storage.py

Lines 365 to 372 in 0b094ca

def _warn_typed_storage_removal(stacklevel=2):
    message = (
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly.  To access UntypedStorage "
        "directly, use tensor.untyped_storage() instead of tensor.storage()"
    )
    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
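
To illustrate why stacklevel leads to repeats, here is a stdlib-only sketch (no torch involved). It assumes Python's "default" warning action, which shows a warning once per location: with stacklevel=2 the location is the caller's line, so three calls on three different lines count as three distinct warnings. The names deprecated_call, count_warnings, three_call_sites and one_call_site are made up for this example.

```python
import warnings


def deprecated_call():
    # stacklevel=2 attributes the warning to the caller's line,
    # not to this line inside the helper.
    warnings.warn("TypedStorage is deprecated", UserWarning, stacklevel=2)


def count_warnings(fn):
    """Run fn under the 'default' action and count the warnings shown."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default")  # once per (message, category, line)
        fn()
    return len(caught)


def three_call_sites():
    # Three different source lines, so three distinct locations.
    deprecated_call()
    deprecated_call()
    deprecated_call()


def one_call_site():
    # The same source line executed three times: a single location.
    for _ in range(3):
        deprecated_call()


if __name__ == "__main__":
    print(count_warnings(three_call_sites))
    print(count_warnings(one_call_site))
```

Under these assumptions the three separate lines each produce a warning, while the loop produces only one, which matches the output in the original report.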

So one option is to remove the stacklevel arg. But having the stack in the message is helpful for finding where the warning came from (#89867)

Another option is to only emit the warning the first time _warn_typed_storage_removal() gets called. For testing purposes, we could have a flag to make it always warn. Does that seem reasonable?
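
A rough sketch of that second option, assuming a module-level flag plus an override for tests. This is not the code that was eventually merged; _warned, _always_warn, set_always_warn and warn_typed_storage_removal are hypothetical names, and the message is shortened for the example.

```python
import warnings

_warned = False        # hypothetical: set after the first warning is emitted
_always_warn = False   # hypothetical: testing override to warn on every call


def set_always_warn(value):
    """Hypothetical switch so tests can force the warning on every call."""
    global _always_warn
    _always_warn = bool(value)


def warn_typed_storage_removal(stacklevel=2):
    """Emit the deprecation warning only on the first call per process."""
    global _warned
    if _warned and not _always_warn:
        return
    _warned = True
    # Keep stacklevel so the first warning still points at the user's code.
    warnings.warn("TypedStorage is deprecated", UserWarning,
                  stacklevel=stacklevel + 1)
```

This keeps the stack information in the one warning that is shown, at the cost of not reporting later call sites unless the override is enabled.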

cc @ezyang

@ezyang
Contributor

ezyang commented Mar 22, 2023

Yeah, let's just have this warn only once and never again. If you're trying to scrub these, you can comment it out

@lorinczszabolcs

Hi! Was this issue already solved? I tried to figure it out based on the PRs, but it seems none of them have been merged / released. Is there any suggested way of silencing it? Thanks!

@kurtamohler
Collaborator

kurtamohler commented Apr 3, 2023

@lorinczszabolcs, #97379 has been merged to fix this issue, but it hasn't been included in a new release yet

Until it's released, you could use the workaround mentioned here: #97207 (comment)

But that will silence all UserWarnings. If you only want to silence the TypedStorage deprecation warnings, you'd have to do something like this:

import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
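
As a quick check that this filter is selective, here is a stdlib-only sketch. The message argument is a regular expression matched against the start of the warning text, so other UserWarnings should still come through. The emit_both helper and both warning texts are stand-ins for illustration, not PyTorch output.

```python
import warnings


def emit_both():
    # Stand-in for the PyTorch deprecation warning (text shortened).
    warnings.warn("TypedStorage is deprecated. (stand-in text)", UserWarning)
    # An unrelated warning that should not be silenced.
    warnings.warn("some other warning", UserWarning)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Added after simplefilter, so it sits at the front of the filter list
    # and takes precedence for messages that match.
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="TypedStorage is deprecated"
    )
    emit_both()

remaining = [str(w.message) for w in caught]
print(remaining)
```

Only the unrelated warning should remain in the recorded list.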

kurtamohler added a commit to kurtamohler/pytorch that referenced this issue Apr 8, 2023
@RodolpheCalvet

Hi there,
I was wondering about the implications and what measures to take. Is this normal behavior? Is PyTorch gracefully falling back to the new UntypedStorage class or not?
While running a Streamlit web server locally everything works, but this warning is one of the possible reasons I am currently considering for a crash when deploying the app on streamlit.io.
I got torch==2.0.0 from pip...
Let me know if more info would be relevant, thanks:

@kurtamohler
Collaborator

@RodolpheCalvet, it sounds like you're experiencing a crash related to UntypedStorage/TypedStorage. Could you please open a new issue and cc me on it? https://github.com/pytorch/pytorch/issues/new?assignees=&labels=&template=bug-report.yml

atalman pushed a commit that referenced this issue Apr 18, 2023
* Only warn once for TypedStorage deprecation (#97379)

Fixes #97207

Pull Request resolved: #97379
Approved by: https://github.com/ezyang

* Specify file encoding in test_torch.py (#97628)

Attempt to fix
```
UnicodeDecodeError: 'ascii' codec can't decode byte 0xe4 in position 5260: ordinal not in range(128)
```
in https://github.com/pytorch/pytorch/actions/runs/4522628359/jobs/7965372405

In general, it's good practice to explicitly specify the encoding, as otherwise it depends on an environment variable and makes test failures unpredictable

Pull Request resolved: #97628
Approved by: https://github.com/dagitses, https://github.com/kit1980

---------

Co-authored-by: Nikita Shulga <nshulga@meta.com>
Xallt added a commit to Xallt/stable-dreamfusion that referenced this issue Aug 22, 2023