
Function(s) expecting a tuple argument don't accept generators #69207

Open
celynw opened this issue Dec 1, 2021 · 1 comment
Labels
module: nn Related to torch.nn Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

celynw commented Dec 1, 2021

πŸ› Bug

From the docs for torch.nn.functional.interpolate, these arguments accept tuples:

  • ...
  • size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]) – output spatial size.
  • scale_factor (float or Tuple[float]) – multiplier for spatial size. If scale_factor is a tuple, its length has to match input.dim().
  • ...

However, if size or scale_factor is a generator, the function fails.

I am not sure if this extends to other similar functions.

To Reproduce

Steps to reproduce the behavior:

size argument

import torch
import torch.nn.functional as F

t = torch.rand([1, 2, 4, 6])
newSize = t.shape[-2:]
F.interpolate(t, size=(d * 2 for d in newSize))

produces this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/redacted/torch/nn/functional.py", line 3712, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
TypeError: upsample_nearest2d() received an invalid combination of arguments - got (Tensor, list, NoneType), but expected one of:
 * (Tensor input, tuple of ints output_size, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (Tensor, list, NoneType)
 * (Tensor input, tuple of ints output_size, float scales_h, float scales_w, *, Tensor out)

scale_factor argument

import torch
import torch.nn.functional as F

t = torch.rand([1, 2, 4, 6])
scale = [1, 1]
F.interpolate(t, scale_factor=(i * 2 for i in scale))

produces this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/redacted/torch/nn/functional.py", line 3678, in interpolate
    if math.floor(scale) != scale:
TypeError: must be real number, not generator

Expected behavior

At the very least, the generators could be evaluated into lists or tuples. Notably, in the size case the traceback even claims that the function received a list rather than a generator!

In both cases, resolving the generator expressions to either a tuple or a list works perfectly.
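For reference, a minimal sketch of the workaround, reusing the names from the snippets above (materialize the generator before the call):

```python
import torch
import torch.nn.functional as F

t = torch.rand([1, 2, 4, 6])
newSize = t.shape[-2:]  # torch.Size([4, 6])

# Materializing the generator expression into a tuple works:
out = F.interpolate(t, size=tuple(d * 2 for d in newSize))
print(out.shape)  # torch.Size([1, 2, 8, 12])

# ...as does passing a list built with a list comprehension:
out = F.interpolate(t, scale_factor=[i * 2.0 for i in (1, 1)])
print(out.shape)  # torch.Size([1, 2, 8, 12])
```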

Environment

PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-41-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.5.119
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 495.29.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy==0.910
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.4
[pip3] pytorch-lightning==1.5.1
[pip3] torch==1.10.0
[pip3] torchmetrics==0.5.1
[pip3] torchvision==0.11.1
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.3.0           h06a4308_520  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.21.2           py38h20f2e39_0  
[conda] numpy-base                1.21.2           py38h79a1101_0  
[conda] pytorch                   1.10.0          py3.8_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-lightning         1.5.1                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchmetrics              0.5.1                    pypi_0    pypi
[conda] torchvision               0.11.1               py38_cu113    pytorch

Additional context

The reason I was doing this in the first place is that I thought using parentheses would give a 'tuple comprehension' rather than a list comprehension. I have now learned that this is actually a generator expression.

I could easily use square brackets instead and move on; it's just that I also expected this to work, and the error message is a bit strange.
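To illustrate the distinction that tripped things up here (plain Python, no torch needed):

```python
# Parentheses around a comprehension create a generator, not a tuple:
gen = (d * 2 for d in [4, 6])
print(type(gen).__name__)  # generator

# A tuple requires an explicit tuple() call around the generator expression:
tup = tuple(d * 2 for d in [4, 6])
print(tup)  # (8, 12)

# Square brackets give a list comprehension:
lst = [d * 2 for d in [4, 6]]
print(lst)  # [8, 12]
```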

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

@H-Huang H-Huang added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Dec 1, 2021
@albanD
Collaborator

albanD commented Dec 1, 2021

I'm not sure it will be easy to support generators for such code.

@github-actions github-actions bot added the Stale label Jan 30, 2022