
python3 -OO option does not work #76034

Closed
A-tA-v opened this issue Apr 19, 2022 · 8 comments
Labels
actionable module: python frontend For issues relating to PyTorch's Python frontend triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@A-tA-v

A-tA-v commented Apr 19, 2022

🐛 Describe the bug

My scripts depend on PyTorch, although I do not use it at the moment.
When I try Python's optimization options, python3 -OO ... fails but python3 -O ... is fine.
This was checked on macOS and Red Hat 7, with Anaconda 4.12, Python 3.9.12, and PyTorch 1.10.2.
The output:

python3 -OO my_script
Traceback (most recent call last):
 ...................................
    import torch as th
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/__init__.py", line 643, in <module>
    from .functional import *  # noqa: F403
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/functional.py", line 6, in <module>
    import torch.nn.functional as F
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/nn/__init__.py", line 1, in <module>
    from .modules import *  # noqa: F403
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/nn/modules/__init__.py", line 2, in <module>
    from .linear import Identity, Linear, Bilinear, LazyLinear
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 6, in <module>
    from .. import functional as F
  File "/Users/albert/miniconda3/envs/qu/lib/python3.9/site-packages/torch/nn/functional.py", line 2231, in <module>
    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)
AttributeError: 'NoneType' object has no attribute 'format'
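For context: under -OO, CPython strips docstrings, so embedding_bag.__doc__ is None and calling .format() on it raises exactly this error. A minimal, PyTorch-free reproduction of the failure mode (a sketch; f and the placeholder name are made up):

```python
import subprocess
import sys

# Under -OO, CPython drops docstrings, so f.__doc__ is None and the
# chained .format() call raises AttributeError, mirroring the traceback above.
code = (
    "def f():\n"
    "    '''note: {placeholder}'''\n"
    "f.__doc__ = f.__doc__.format(placeholder='x')\n"
)
result = subprocess.run(
    [sys.executable, "-OO", "-c", code],
    capture_output=True, text=True,
)
print(result.returncode)                  # non-zero: the .format() call failed
print("AttributeError" in result.stderr)  # True
```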

Versions

Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.3.1 (x86_64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.3)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.12 (main, Apr 5 2022, 01:53:17) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.2
[pip3] torch==1.10.2
[conda] blas 1.0 mkl
[conda] libblas 3.9.0 13_osx64_mkl conda-forge
[conda] libcblas 3.9.0 13_osx64_mkl conda-forge
[conda] liblapack 3.9.0 13_osx64_mkl conda-forge
[conda] mkl 2022.0.0 hecd8cb5_105
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.2 py39h9d9ce41_0 conda-forge
[conda] pytorch 1.10.2 py3.9_0 pytorch

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

@anjali411 anjali411 added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 19, 2022
@albanD
Collaborator

albanD commented Apr 19, 2022

My understanding is that the -OO option removes all docstrings, so it is possible that a lot of our code fails today.
We would be happy to accept a PR fixing such issues to make sure you can run PyTorch with -OO.

In this case, I think we can simply check if the __doc__ is None and only format it if it exists.
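The guard described here could look like the following sketch (embedding_bag and reproducibility_notes are simplified stand-ins for the real names in torch/nn/functional.py):

```python
# Simplified stand-ins for the real objects in torch/nn/functional.py.
reproducibility_notes = {"backward_reproducibility_note": "results may vary across runs"}

def embedding_bag(input, weight):
    """Compute sums of 'bags' of embeddings.

    Note:
        {backward_reproducibility_note}
    """

# Under -OO the docstring is already None, so only format it when it survived.
if embedding_bag.__doc__ is not None:
    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)
```

In normal mode the placeholder gets substituted; under -OO the whole block is skipped and import succeeds.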

@albanD albanD added actionable module: python frontend For issues relating to PyTorch's Python frontend and removed module: nn Related to torch.nn labels Apr 19, 2022
@vitrioil
Contributor

I can contribute to this. I'm not sure how much improvement it will bring, but I can check.

I had a quick look at how __doc__ is being used:

  1. Assigning to __doc__:

    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)

Adding a None check should solve this.

  2. Directly:

    class Conv1d(_ConvNd):
        __doc__ = r"""Applies a 1D convolution over an input signal composed of several input

This will NOT be removed by Python; it has to be handled programmatically (because of the format call). (Maybe similar to 1., or some other way.)

  3. Getting __doc__ using getattr and performing operations on it:

    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."
    # Do surgery on pybind11 generated docstrings.
    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")

This one has an assert to stop it, but asserts are removed when the -OO flag is set, so we need to either move everything inside the check or replace None with ''.
Like:

    docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
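That `or ''` fallback can be exercised like this (documented/undocumented are illustrative functions, not PyTorch code):

```python
import textwrap

def documented():
    """
    A docstring with leading indentation.
    """

def undocumented():
    pass  # __doc__ is None here, just as it is for everything under -OO

# `or ''` turns a missing or stripped docstring into an empty string, so the
# subsequent dedent (and any later .replace()/.format() calls) never hit None.
for fn in (documented, undocumented):
    docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
    print(repr(docstring))
```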

I see some asserts in tests, but I guess those are irrelevant here.

I will also check if dependencies are compatible.

What do you think?

@albanD
Collaborator

albanD commented Apr 27, 2022

Thanks for looking into this.
For 1. you can indeed add a None check.
For 2. I would simply remove the __doc__ = and leave that string at the top. Python will properly put it in the __doc__ field and it will also be properly removed in optimized mode.
For 3. I think we can guard this check with if __debug__:.
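Guarding the docstring surgery for 3. with __debug__ might look like this sketch (the names are illustrative, not the actual torch.distributed.rpc code):

```python
# Illustrative sketch: __debug__ is False under -O/-OO, so the assert and the
# docstring rewrite are skipped together instead of crashing on a None value.
def _patch_docstring(method):
    if __debug__:
        docstring = getattr(method, "__doc__", None)
        assert docstring is not None, "user-facing methods should all have docstrings."
        # Do surgery on the generated docstring.
        method.__doc__ = docstring.replace("PyRRef", "RRef")

class Demo:
    def call(self):
        """Calls PyRRef."""

_patch_docstring(Demo.call)
```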

For testing, I think it is safe to assume that testing does not run in optimized mode.
But we do want a test so that this does not regress in the future. You can do something similar to this test:

def test_private_stateless_warns(self):
    script = """
import torch
import warnings
with warnings.catch_warnings(record=True) as w:
    from torch.nn.utils import _stateless
    exit(len(w))
"""
    try:
        subprocess.check_output(
            [sys.executable, '-W', 'all', '-c', script],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),)
    except subprocess.CalledProcessError as e:
        self.assertEqual(e.returncode, 1)
    else:
        self.assertTrue(False, "No warning was raised.")
and make sure that import torch does work in optimized mode.
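A regression test could reuse the same subprocess pattern, this time launching the child interpreter with -OO (a sketch: in PyTorch's suite the module would be "torch"; here any importable module works):

```python
import subprocess
import sys

def check_import_in_optimized_mode(module_name):
    # Spawn a fresh interpreter with -OO; check_output raises
    # CalledProcessError if the import crashes, failing the test.
    subprocess.check_output(
        [sys.executable, "-OO", "-c", f"import {module_name}"],
        stderr=subprocess.STDOUT,
    )

# Demonstrate with a stdlib module; a PyTorch test would pass "torch".
check_import_in_optimized_mode("json")
```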

@vitrioil
Contributor

Cool.

Just one thing: for 2., removing the __doc__ = and keeping the string as-is will work, but not if format is used on it.

class Works:
    """I won't stay with -OO :)"""

class Problem:
    """I am not a docstring anymore {}""".format(":(")

Problem:

class Conv1d(_ConvNd):
    __doc__ = r"""Applies a 1D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
    precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`L` is a length of signal sequence.
    """ + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a one-element tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.

    {groups_note}

    Note:
        {depthwise_separable_note}
    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to both sides of
            the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """.format(**reproducibility_notes, **convolution_notes) + r"""

So we might have to keep it as-is and then, after the declaration, do the formatting only if __doc__ exists, for all occurrences.

class WorkAround:
    """I am a {} docstring that won't stay with -OO :)"""

if WorkAround.__doc__:
    WorkAround.__doc__ = WorkAround.__doc__.format("formatted")
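A quick check that this guard behaves in both modes, by running the snippet in a normal and an optimized child interpreter:

```python
import subprocess
import sys

snippet = '''
class WorkAround:
    """I am a {} docstring that won't stay with -OO :)"""

if WorkAround.__doc__:
    WorkAround.__doc__ = WorkAround.__doc__.format("formatted")

print(WorkAround.__doc__)
'''

# Normal mode: the docstring survives and gets formatted.
normal = subprocess.run([sys.executable, "-c", snippet],
                        capture_output=True, text=True)
print(normal.stdout.strip())

# Optimized mode: the docstring is stripped, the guard skips the
# format call, and the snippet still runs without an AttributeError.
optimized = subprocess.run([sys.executable, "-OO", "-c", snippet],
                           capture_output=True, text=True)
print(optimized.stdout.strip())     # prints "None"
```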

@albanD
Collaborator

albanD commented Apr 28, 2022

That sounds like a good plan! A bit unfortunate that we have to do that but hey!

@vitrioil
Contributor

Created #76619. There are still a lot of changes pending. Should I make it a draft PR?

@albanD
Collaborator

albanD commented May 2, 2022

Thanks for the fix!
I think we can open a new issue for the follow-up improvements (as it now works, but not optimally).

Also I would be curious to discuss there the impact of the remaining docstrings.

@vitrioil
Contributor

vitrioil commented May 2, 2022

Created #76659

facebook-github-bot pushed a commit that referenced this issue May 3, 2022
Summary:
Fixes #76034

This does not let Python remove all `__doc__` strings, because in some places `__doc__` is explicitly assigned a string.

Example:
https://github.com/pytorch/pytorch/blob/04b3313379712098183dfe5bea002a5e43b5af48/torch/nn/modules/conv.py#L174-L233

Since there are quite a few of these, I will add all of them together in this PR later. (Basically, a lot of docstrings will still persist even with `-OO` enabled.)

Pull Request resolved: #76619
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f92cddd89030af9389f20d7c75c0831efbf9a2ba

Reviewed By: malfet

Differential Revision: D36073490

fbshipit-source-id: 49efc08a09facefa705fc25d1ad2bd368d86ca28