python3 -OO option does not work #76034
Comments
My understanding is that the In this case, I think we can simply check if the
I can contribute to this. I'm not sure how much improvement it would bring, but I can check. I had a quick look at how `__doc__` is being used:
I see some asserts in the tests, but I guess those would be irrelevant here. I will also check whether the dependencies are compatible. What do you think?
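As an aid for that kind of audit (not part of the original discussion), here is a minimal sketch that uses the standard-library `ast` module to list every place a source file reads or assigns `__doc__`. `find_doc_usages` and `SAMPLE` are hypothetical names chosen for illustration.

```python
from __future__ import annotations

import ast


def find_doc_usages(source: str) -> list[tuple[int, str]]:
    """Return sorted (line, kind) pairs for every `__doc__` read or write in `source`."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Attribute) and node.attr == "__doc__":
            # e.g. `Cls.__doc__.format(...)` (read) or `Cls.__doc__ = ...` (write)
            kind = "read" if isinstance(node.ctx, ast.Load) else "write"
            hits.append((node.lineno, kind + " attribute"))
        elif isinstance(node, ast.Name) and node.id == "__doc__":
            # e.g. a bare `__doc__ = "..."` assignment inside a class body
            kind = "read" if isinstance(node.ctx, ast.Load) else "write"
            hits.append((node.lineno, kind + " name"))
    return sorted(hits)


SAMPLE = '''
class A:
    """Doc {}"""

A.__doc__ = A.__doc__.format("x")

class B:
    __doc__ = "assigned"
'''

for line, kind in find_doc_usages(SAMPLE):
    print(line, kind)
```

Reads of `X.__doc__` are the spots that can fail under `-OO`; bare `__doc__ = ...` writes are the ones that keep a docstring alive.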
Thanks for looking into this. For testing, I think it is safe to assume that the test suite does not run in optimized mode.
pytorch/test/test_stateless.py Lines 159 to 179 in ea19016
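If that assumption ever needs to be explicit, one hedged option is to skip docstring-dependent tests based on the standard-library `sys.flags.optimize` value. `DocstringTest` and `DOCSTRINGS_STRIPPED` are hypothetical names, not taken from `test_stateless.py`.

```python
import sys
import unittest

# sys.flags.optimize is 0 normally, 1 under -O and 2 under -OO.
# Only -OO strips docstrings, leaving __doc__ set to None.
DOCSTRINGS_STRIPPED = sys.flags.optimize >= 2


class DocstringTest(unittest.TestCase):
    @unittest.skipIf(DOCSTRINGS_STRIPPED, "docstrings are stripped under -OO")
    def test_docstring_present(self):
        def documented():
            """Has a docstring."""

        self.assertEqual(documented.__doc__, "Has a docstring.")


if __name__ == "__main__":
    # exit=False so the runner returns instead of calling sys.exit().
    unittest.main(argv=["docstring-test"], exit=False)
```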
Cool. Just one thing, for 2 removing

```python
class Works:
    """I won't stay with -OO :)"""

class Problem:
    """I am not a docstring anymore {}""".format(":(")
```

Problem: pytorch/torch/nn/modules/conv.py Lines 174 to 233 in 04b3313

So we might have to keep it as it is and then, after the declaration, do the formatting only if the docstring is present:

```python
class WorkAround:
    """I am a {} docstring that won't stay with -OO :)"""

if WorkAround.__doc__:
    WorkAround.__doc__ = WorkAround.__doc__.format("formatted")
```
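A quick way to check the `WorkAround` pattern in both modes is to run it in fresh interpreters with and without the flag. This is a sketch that assumes running `sys.executable` in a subprocess is acceptable; `run_with_flags` is a hypothetical helper.

```python
import subprocess
import sys
import textwrap

# The WorkAround pattern as a standalone script.
SNIPPET = textwrap.dedent('''
    class WorkAround:
        """I am a {} docstring that won't stay with -OO :)"""

    # Under -OO the docstring is stripped and __doc__ is None, so guard it.
    if WorkAround.__doc__:
        WorkAround.__doc__ = WorkAround.__doc__.format("formatted")

    print(WorkAround.__doc__)
''')


def run_with_flags(*flags: str) -> str:
    """Run SNIPPET in a fresh interpreter with the given flags; return its stdout."""
    proc = subprocess.run(
        [sys.executable, *flags, "-c", SNIPPET],
        capture_output=True,
        text=True,
        check=True,  # any exception in the snippet would raise CalledProcessError
    )
    return proc.stdout.strip()


print(run_with_flags())       # docstring kept and formatted
print(run_with_flags("-OO"))  # docstring stripped, prints None
```

The guard trades a formatted docstring for no docstring at all under `-OO`, which is the behavior `-OO` asks for anyway.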
That sounds like a good plan! A bit unfortunate that we have to do that, but hey!
Created #76619. There are still a lot of changes pending. Should I make it a draft PR?
Thanks for the fix! I would also be curious to discuss the impact of the remaining docstrings there.
Created #76659
Summary: Fixes #76034.

This does not make Python remove all `__doc__` strings, because in some places `__doc__` is assigned to a string. Example: https://github.com/pytorch/pytorch/blob/04b3313379712098183dfe5bea002a5e43b5af48/torch/nn/modules/conv.py#L174-L233
Since there are quite a few of these, I will add all of them together in this PR later. (In short, many docstrings will still persist even with `-OO` enabled.)

Pull Request resolved: #76619
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f92cddd89030af9389f20d7c75c0831efbf9a2ba

Reviewed By: malfet

Differential Revision: D36073490

fbshipit-source-id: 49efc08a09facefa705fc25d1ad2bd368d86ca28
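To illustrate why many docstrings persist: `-OO` drops only literal docstrings, while a string explicitly assigned to `__doc__` (the pattern used in `conv.py`) is an ordinary constant and survives. The class names below are illustrative, not taken from PyTorch.

```python
import subprocess
import sys

# A literal docstring versus an explicit `__doc__ = ...` assignment in the class body.
SNIPPET = '''
class Literal:
    """stripped by -OO"""

class Assigned:
    __doc__ = "kept under -OO"

print(Literal.__doc__, "|", Assigned.__doc__)
'''

result = subprocess.run(
    [sys.executable, "-OO", "-c", SNIPPET],
    capture_output=True,
    text=True,
    check=True,
)
print(result.stdout.strip())  # None | kept under -OO
```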
🐛 Describe the bug
My scripts depend on PyTorch, although I do not use it at the moment.
When I try the optimization options, `python3 -OO ...` fails but `python3 -O ...` works fine.
This was checked on macOS and Red Hat 7 (Anaconda environment 4.12, Python 3.9.12, PyTorch 1.10.2).
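The difference between the two flags explains the symptom: `-O` removes `assert` statements and `if __debug__:` blocks, while `-OO` additionally strips docstrings, so any import-time code that edits a docstring sees `None`. The snippet below is a hypothetical minimal stand-in, not the actual PyTorch code path (the original traceback was not preserved).

```python
import subprocess
import sys

# Hypothetical stand-in for library code that edits a docstring at import time.
SNIPPET = '''
def f():
    """Computes {}."""

f.__doc__ = f.__doc__.format("something")
'''


def run(*flags: str) -> subprocess.CompletedProcess:
    """Execute SNIPPET in a fresh interpreter started with the given flags."""
    return subprocess.run(
        [sys.executable, *flags, "-c", SNIPPET],
        capture_output=True,
        text=True,
    )


ok = run("-O")       # docstrings are kept, so .format() succeeds
broken = run("-OO")  # f.__doc__ is None, so None.format(...) raises
print(ok.returncode)      # 0
print(broken.returncode)  # 1 (uncaught AttributeError)
print(broken.stderr.strip().splitlines()[-1])
```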
The output:
Versions
Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.3.1 (x86_64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.3)
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.12 (main, Apr 5 2022, 01:53:17) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.2
[pip3] torch==1.10.2
[conda] blas 1.0 mkl
[conda] libblas 3.9.0 13_osx64_mkl conda-forge
[conda] libcblas 3.9.0 13_osx64_mkl conda-forge
[conda] liblapack 3.9.0 13_osx64_mkl conda-forge
[conda] mkl 2022.0.0 hecd8cb5_105
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.2 py39h9d9ce41_0 conda-forge
[conda] pytorch 1.10.2 py3.9_0 pytorch
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345