
torch.bartlett_window not jitable #32358

Open
vincentqb opened this issue Jan 17, 2020 · 5 comments
Labels
oncall: jit, triaged

Comments

@vincentqb
Contributor

vincentqb commented Jan 17, 2020

🐛 Bug

In [1]: import torch

In [2]: method = torch.bartlett_window

In [3]: jit_method = torch.jit.script(method)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-e9dbd59dab67> in <module>
----> 1 jit_method = torch.jit.script(method)

~/anaconda3/envs/audio-built/lib/python3.7/site-packages/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1255         return torch.jit._recursive.recursive_script(obj)
   1256 
-> 1257     qualified_name = _qualified_name(obj)
   1258     if inspect.isclass(obj):
   1259         # If this type is a `nn.Module` subclass, they probably meant to pass

~/anaconda3/envs/audio-built/lib/python3.7/site-packages/torch/_jit_internal.py in _qualified_name(obj)
    694     if module_name is None:
    695         raise RuntimeError("Could not get qualified name for class '{}': "
--> 696                            "__module__ can't be None.".format(name))
    697 
    698     # if getattr(sys.modules[module_name], name) is not obj:

RuntimeError: Could not get qualified name for class 'bartlett_window': __module__ can't be None.

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: version 3.15.2

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.4
[pip3] torch==1.1.0.post2
[pip3] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 233
[conda] mkl-service 2.3.0 py37hfbe908c_0
[conda] mkl_fft 1.0.15 py37h5e564d8_0
[conda] mkl_random 1.1.0 py37ha771720_0
[conda] pytorch 1.4.0 py3.7_0 pytorch
[conda] torchaudio 0.5.0a0+2c49528 pypi_0 pypi

Additional context

The torchaudio JIT-ability test is failing for dither; see pytorch/audio#417.

cc @suo

vincentqb added the oncall: jit label Jan 17, 2020
@zdevito
Contributor

zdevito commented Jan 19, 2020

I think the dither error and this error are different. I don't think we support scripting builtins like bartlett_window directly, but the error in torchaudio seems legit.

@zdevito
Contributor

zdevito commented Jan 19, 2020

Actually, it does look like a function. Not sure what is happening here. @driazati, can you check how recursive scripting handles this Python function?

zdevito added the triaged label Jan 19, 2020
@driazati
Contributor

driazati commented Jan 21, 2020

Since it's bound directly from C++, there's no Python code or module to compile. We could make script construct a graph with a single op (the call to the builtin), but that gets complicated for overloaded ops (e.g. torch.add). Overloaded builtins could work if passed to torch.jit.trace with some sample inputs, but that doesn't seem like a great API when users are using script most of the time.

@vincentqb, for a quick workaround, you can wrap torch.bartlett_window in a Python function and call it from a @torch.jit.script function:

import torch

def my_bartlett_window(x: int):
    return torch.bartlett_window(x)

@torch.jit.script
def my_script_function(x: int):
    return my_bartlett_window(x)
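The distinction drawn above, between a builtin bound directly from C++ and a Python wrapper around it, can be seen from plain Python. The stdlib-only sketch below is only an illustration, not how TorchScript itself inspects functions: `len` stands in for a C++-bound op such as torch.bartlett_window, and `has_python_code` and `wrapper` are hypothetical names.

```python
import inspect
import types


def has_python_code(fn) -> bool:
    # Functions defined in Python carry a code object; C-implemented builtins do not.
    return isinstance(getattr(fn, "__code__", None), types.CodeType)


def wrapper(xs):
    # A plain Python wrapper around a builtin has a code object of its own,
    # which is what gives a source-based compiler something to work with.
    return len(xs)


print(has_python_code(len), inspect.isbuiltin(len))
print(has_python_code(wrapper), inspect.isfunction(wrapper))
```

The first line prints `False True` and the second `True True`: the builtin exposes no Python code, while the wrapper does.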

@vincentqb
Contributor Author


This does work in the interpreter. However, calling the scripted function with any integer other than 0 or 1 raises an error; see the code below and the CircleCI run.

In [1]: import torch

In [2]: def my_bartlett_window(x: int):
   ...:     return torch.bartlett_window(x)

In [3]: torch.jit.script(my_bartlett_window)
Out[3]: <torch.jit.ScriptFunction at 0x126ebcbf0>

In [14]: torch.jit.script(my_bartlett_window)(1)
Out[14]: tensor([1.])

In [15]: torch.jit.script(my_bartlett_window)(2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-3490fd8fd399> in <module>
----> 1 torch.jit.script(my_bartlett_window)(2)

RuntimeError: result type Float can't be cast to the desired output type Long
The above operation failed in interpreter.
Traceback (most recent call last):
  File "<ipython-input-12-bcedcc6c5cd2>", line 2
def my_bartlett_window(x: int):
    return torch.bartlett_window(x)
           ~~~~~~~~~~~~~~~~~~~~~ <--- HERE

@driazati
Contributor

driazati commented Jan 27, 2020

There are currently some bugs in how TensorOptions are handled in TorchScript versus eager mode. Until they are fixed, you can sidestep them by passing the dtype to torch.bartlett_window explicitly:

import torch

def my_bartlett_window(x: int):
    return torch.bartlett_window(x, dtype=torch.float)

@torch.jit.script
def my_script_function(x: int):
    return my_bartlett_window(x)

print(my_bartlett_window(2))
print(my_script_function(2))
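Since the reported failure only appeared for lengths other than 0 and 1, a quick way to gain confidence in the explicit-dtype workaround is to compare eager and scripted results over a range of lengths. This is a minimal sketch, assuming a PyTorch build where the workaround compiles; the helper `check_eager_vs_script` and its `max_len` parameter are hypothetical, not part of PyTorch.

```python
import torch


def my_bartlett_window(x: int) -> torch.Tensor:
    # Explicit dtype avoids relying on TorchScript's handling of default TensorOptions.
    return torch.bartlett_window(x, dtype=torch.float)


# Script the Python wrapper (it has source to compile, unlike the builtin itself).
scripted = torch.jit.script(my_bartlett_window)


def check_eager_vs_script(max_len: int = 8) -> bool:
    # Compare dtype and values of eager vs. scripted output for lengths 0..max_len.
    for n in range(max_len + 1):
        eager = my_bartlett_window(n)
        script = scripted(n)
        if eager.dtype != script.dtype or not torch.equal(eager, script):
            return False
    return True


print(check_eager_vs_script())
print(scripted(2))
```

The same loop can be dropped into a jitability test, in the spirit of the torchaudio check that prompted this issue.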

vincentqb added a commit to vincentqb/audio that referenced this issue Jan 27, 2020
vincentqb added a commit to pytorch/audio that referenced this issue Jan 29, 2020
* workaround for bartlett_window pytorch/pytorch#32358 (comment)

* only change dtype.
driazati removed their assignment Nov 14, 2021