Cannot load certain function from dumped Torchscript file #42258

Open
mthrok opened this issue Jul 29, 2020 · 11 comments
Labels
days, high priority, oncall: jit, triage review

@mthrok
Contributor

mthrok commented Jul 29, 2020

🐛 Bug

The JIT throws RuntimeError: classType INTERNAL ASSERT FAILED at "../torch/csrc/jit/python/pybind_utils.h":894, please report a bug to PyTorch. when running certain functions loaded from a dumped TorchScript file. This does not happen with every function I tested.

Interestingly, this does not happen when the scripted function is dumped to a file and loaded back in the same Python process. It only happens when dumping and loading are done in two separate Python invocations.

To Reproduce

Steps to reproduce the behavior:

  1. conda install torchaudio=0.6.0 -c pytorch

  2. Run the following script for the first time. It generates Torchscript objects in the directory foo.

    Script
    import os
    import tempfile
    from typing import Optional
    
    import torch
    import torchaudio
    
    
    def info(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
        return torchaudio.info(filepath)
    
    
    def load(
            filepath: str,
            frame_offset: int,
            num_frames: int,
            normalize: bool,
            channels_first: bool):
        return torchaudio.load(filepath, frame_offset, num_frames, normalize, channels_first)
    
    
    def save(
            filepath: str,
            tensor: torch.Tensor,
            sample_rate: int,
            channels_first: bool = True,
            compression: Optional[float] = None,
    ):
        torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
    
    
    def generate(output_dir):
        torchaudio.set_audio_backend('sox_io')
    
        funcs = [
            info,
            load,
            save,
        ]
    
        os.makedirs(output_dir, exist_ok=True)
        for func in funcs:
            torch.jit.script(func).save(os.path.join(output_dir, f'{func.__name__}.zip'))
    
    
    def validate(input_dir):
        torchaudio.set_audio_backend('sox_io')
    
        info_ = torch.jit.load(os.path.join(input_dir, 'info.zip'))
        load_ = torch.jit.load(os.path.join(input_dir, 'load.zip'))
        save_ = torch.jit.load(os.path.join(input_dir, 'save.zip'))
    
        sample_rate = 44100
        normalize = True
        channels_first = True
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = os.path.join(temp_dir, 'test.wav')
            temp_data = torch.rand(2, sample_rate, dtype=torch.float32)
    
            save_(temp_file, temp_data, sample_rate, channels_first, 0.)
            info_(temp_file)
            load_(temp_file, 0, -1, normalize, channels_first)
    
    
    generate('foo')
    validate('foo')
  3. Now comment out generate('foo') and run the script again. It fails with:

    Traceback (most recent call last):
     File "foo.py", line 66, in <module>
       validate('foo')
     File "foo.py", line 61, in validate
       info_(temp_file)
     File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
       result = self.forward(*input, **kwargs)
    RuntimeError: classType INTERNAL ASSERT FAILED at "../torch/csrc/jit/python/pybind_utils.h":894, please report a bug to PyTorch.
    
  4. Strangely, only the info function triggers this error, so commenting out info_(temp_file) makes the script pass. Here is the Python code, and the C++ code is found here. I do not know if this is related, but one difference between the failing info function and the others is that info uses a Python class (AudioMetaData).
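
Steps 2 and 3 amount to running each phase in its own Python process. As a minimal sketch of how that could be done without editing the script between runs, a small command-line switch can pick the phase. The helper name `run_phase` and the `--dir` option are hypothetical and not part of the original script.

```python
import argparse
from typing import Callable, Dict, List, Optional


def run_phase(
    phases: Dict[str, Callable[[str], None]],
    argv: Optional[List[str]] = None,
) -> str:
    """Run exactly one phase chosen on the command line and return its name."""
    parser = argparse.ArgumentParser(description="Run one phase per Python process.")
    parser.add_argument("phase", choices=sorted(phases))
    parser.add_argument("--dir", default="foo", help="directory holding the dumped files")
    args = parser.parse_args(argv)
    # Each phase receives the directory, matching generate(output_dir) and
    # validate(input_dir) in the script above.
    phases[args.phase](args.dir)
    return args.phase
```

Replacing the last two lines of the script with `run_phase({"generate": generate, "validate": validate})` would then allow `python foo.py generate` followed by `python foo.py validate`, with each call running in a fresh interpreter.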

Expected behavior

Torchscript should work for info_(temp_file) too.

Environment

This happens in multiple environments.

Env 1 master build
$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.7.0a0+fd9205e
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 418.116.00
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.7.0a0+fd9205e
[pip3] torchaudio==0.7.0a0+71a797d
[conda] magma-cuda101             2.5.2                         1    pytorch
[conda] mkl                       2020.1                      217
[conda] mkl-include               2020.1                      219    conda-forge
[conda] numpy                     1.18.5           py38h8854b6b_0    conda-forge
[conda] pytorch-sphinx-theme      0.0.24                    dev_0    <develop>
[conda] torch                     1.7.0a0+fd9205e           dev_0    <develop>
[conda] torchaudio                0.7.0a0+71a797d           dev_0    <develop>
Env 2 official release
$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: No
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 418.116.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.6.0
[pip3] torchaudio==0.6.0a0+f17ae39
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.2.89              hfd86e86_1
[conda] mkl                       2020.1                      217
[conda] mkl-service               2.3.0            py36he904b0f_0
[conda] mkl_fft                   1.1.0            py36h23d657b_0
[conda] mkl_random                1.1.1            py36h0573a6f_0
[conda] numpy                     1.18.5           py36ha1c710e_0
[conda] numpy-base                1.18.5           py36hde5b4d6_0
[conda] pytorch                   1.6.0           py3.6_cuda10.2.89_cudnn7.6.5_0    pytorch
[conda] torchaudio                0.6.0                      py36    pytorch

Additional context

I encountered this issue while working on the BC check test in pytorch/audio#838.

cc @ezyang @gchanan @zou3519 @suo @gmagogsfm

@facebook-github-bot facebook-github-bot added the oncall: jit label Jul 29, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Jul 29, 2020
@wconstab wconstab added the days label Jul 30, 2020
@wconstab wconstab moved this from Need triage to Pending in JIT Triage Jul 30, 2020
@wconstab wconstab moved this from Pending to HIGH PRIORITY in JIT Triage Jul 30, 2020
@gmagogsfm
Contributor

Hi mthrok,

Thanks for reporting this issue. Your guess is correct: the Python class AudioMetaData is exactly what causes trouble in type parsing. If you are looking for a quick work-around, I think rewriting the implementation of info without AudioMetaData should get you going.

The superficial issue is that when importing the saved module, a type __torch__.torchaudio.backend.sox_io_backend.AudioMetaData is found and registered to a new CompilationUnit, but then the entire CompilationUnit seems to be discarded when converting to a Python ScriptModule.

@SplitInfinity, could you provide some insight? I tried replacing the new CompilationUnit with the global CompilationUnit from torch.jit._state and the error went away, but it then hits a later check when trying to get the implementation of the aforementioned class:

RuntimeError: Unknown reference to ScriptClass __torch__.torchaudio.backend.sox_io_backend.AudioMetaData. Did you forget to import it?

@mthrok
Contributor Author

mthrok commented Aug 5, 2020

Hi @gmagogsfm

Thanks for looking into this.

> If you are looking for a quick work-around, I think rewriting implementation of info without AudioMetaData should get you going.

We will hold off on taking that action because the change would be BC-breaking on the torchaudio side.
However, I have a question: if we were to change the implementation, would a dict type work?

@SplitInfinity

@gmagogsfm

We maintain a global map of name -> Python class type for script classes we compile like AudioMetaData. Every time a class is compiled either explicitly or implicitly, a corresponding entry is added to this map. The main purpose is so that we can look up and call __new__ to create a new instance of the class when converting from IValue -> PyObject (e.g. when a function returns a script class, like the case of info_). __new__ is not scripted nor serialized.

So I think what's going on is that in addition to the function and the script class being in separate compilation units, we are not populating the aforementioned global map for script classes when we deserialize them. Even if we did that, we don't serialize __new__.

So yeah, I think that changing the implementation is the best idea.

@gmagogsfm
Contributor

@mthrok I think either a dictionary or a named tuple would be OK.
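
A rough sketch of what those two return shapes could look like, written as plain annotated Python so it runs without PyTorch. The name `AudioInfo` and the three field names are assumptions made for illustration, and the functions take the values as arguments instead of reading a file.

```python
from typing import Dict, NamedTuple


# Hypothetical named-tuple replacement for the class; field names are assumed.
class AudioInfo(NamedTuple):
    sample_rate: int
    num_frames: int
    num_channels: int


def info_as_tuple(sample_rate: int, num_frames: int, num_channels: int) -> AudioInfo:
    # A named tuple keeps attribute access (result.sample_rate) for callers.
    return AudioInfo(sample_rate, num_frames, num_channels)


def info_as_dict(sample_rate: int, num_frames: int, num_channels: int) -> Dict[str, int]:
    # Every value is an int because a TorchScript Dict has a single value type.
    return {
        "sample_rate": sample_rate,
        "num_frames": num_frames,
        "num_channels": num_channels,
    }
```

With PyTorch installed, either function could then be passed to torch.jit.script. Whether the saved result loads cleanly in a separate process has not been checked here. The dictionary variant changes how callers read the fields, which is part of why the change would be BC-breaking.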

@mthrok
Contributor Author

mthrok commented Aug 6, 2020

@SplitInfinity

Thanks for the details. I understand that changing the implementation will make it work.
Is this the expected behavior? Or is it something that could potentially be improved in the future?

@SplitInfinity

> Is this the expected behavior?

No, but it's probably not a use case that was foreseen based on the design. TorchScript classes are also still considered experimental.

> or is it something potentially improved in future?

I'm working on a redesign of class types that will most likely get rid of the dependency of this feature on __new__, so it might be potentially improved in the future.

mthrok added a commit to pytorch/audio that referenced this issue Aug 6, 2020
This change adds a CI test to check the backward compatibility of TorchScript function/object dumps.

The job first dumps TorchScript objects to files from the torchaudio 0.6.0 release environment (Python 3.6, 3.7, 3.8), then loads and runs the functions in the master build (Python 3.6, 3.7, 3.8).

If there is a BC-breaking change in the master build (a registration schema change), the test should fail.

At this moment, the `info` function does not work due to a suspected bug on the torch side, so the test is disabled for `info`. See pytorch/pytorch#42258 for details. Once pytorch/pytorch#42258 is resolved, we can enable it.
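
The load-and-run phase with one check disabled could be organized roughly as below. This is a sketch under assumptions, not the code from the referenced commit: `run_bc_checks` and its parameters are hypothetical, and each check is expected to load the dumped file at the given path and exercise it.

```python
import os
from typing import Callable, Dict, Set


def run_bc_checks(
    input_dir: str,
    checks: Dict[str, Callable[[str], None]],
    skip: Set[str],
) -> Dict[str, str]:
    """Run each check against '<name>.zip' and record 'ok', 'skipped', or 'failed: ...'."""
    results: Dict[str, str] = {}
    for name, check in checks.items():
        if name in skip:
            # Disabled checks stay listed so they are easy to re-enable later.
            results[name] = "skipped"
            continue
        path = os.path.join(input_dir, f"{name}.zip")
        try:
            check(path)  # expected to load the dump and call the function
            results[name] = "ok"
        except Exception as exc:  # collect every failure instead of stopping early
            results[name] = f"failed: {exc}"
    return results
```

A caller would pass `skip={"info"}` while the issue is open, and an empty set once it is resolved.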
@mthrok
Contributor Author

mthrok commented Aug 6, 2020

@SplitInfinity

Thanks for the clarification.

One more question.

> No, but it's probably not a use case that was foreseen based on the design. TorchScript classes are also still considered experimental.

If it were fixed in the future, would the currently dumped files work? Or is that uncertain at the moment?

@gmagogsfm
Contributor

There actually is a work-around (with a limitation) at the moment.

In the program that loads these saved modules, you can import the definition of AudioMetaData (also decorating AudioMetaData with @torch.jit.script), and your program should run successfully. However, this comes with the limitation that your program must be in Python. So if you want to serve this model in a production environment, it might not work well.

@SplitInfinity

> In the program that loads these saved modules, you can import definition of AudioMetaData (also decorating AudioMetaData with @torch.jit.script) and your program should run successfully. However, this does come with the limitation that your program must be in Python.

Did you try this? I wonder if there will be a naming conflict in the compilation unit.

@gmagogsfm
Contributor

> > In the program that loads these saved modules, you can import definition of AudioMetaData (also decorating AudioMetaData with @torch.jit.script) and your program should run successfully. However, this does come with the limitation that your program must be in Python.

> Did you try this? I wonder if there will be a naming conflict in the compilation unit.

I was wondering the same thing. But there is no name conflict, and the failure goes away after adding @torch.jit.script to the AudioMetaData class. I guess this is because the compilation unit used during import is discarded?

@yf225
Contributor

yf225 commented Aug 7, 2020

> > In the program that loads these saved modules, you can import definition of AudioMetaData (also decorating AudioMetaData with @torch.jit.script) and your program should run successfully. However, this does come with the limitation that your program must be in Python.

> Did you try this? I wonder if there will be a naming conflict in the compilation unit.

Yeah I tried and this works:

test_common.py

from typing import Dict, List, Optional
import torch

@torch.jit.script
class TypedDataDict(object):
    def __init__(
        self,
        str_to_tensor: Optional[Dict[str, torch.Tensor]] = None,
        str_to_list_of_str: Optional[Dict[str, List[str]]] = None
    ):
        self.str_to_tensor = str_to_tensor
        self.str_to_list_of_str = str_to_list_of_str

test_save.py

import torch

from test_common import TypedDataDict

class TestModule(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, input: torch.Tensor):
    return TypedDataDict(None, None)

m = TestModule()
m_scripted = torch.jit.script(m)
m_scripted.save("TypedDataDict.pt")

test_load.py

import torch

from test_common import TypedDataDict

m_scripted = torch.jit.load("TypedDataDict.pt")
typed_data_dict = m_scripted(torch.tensor(1.))
print(typed_data_dict)

If I run test_save.py and then test_load.py, it is able to access the content of the TorchScript class object.
