Cannot load certain function from dumped Torchscript file #42258
Comments
Hi mthrok, thanks for reporting this issue. Your guess is correct: it is exactly the Python class that is the problem. The superficial issue is that, when importing the saved module, the type is resolved against a newly created CompilationUnit. @SplitInfinity, could you provide some insight? I tried replacing the new CompilationUnit with the global CompilationUnit from torch.jit._state, and the error is gone, but it then hits a later check when trying to get the implementation of the aforementioned class:
Hi @gmagogsfm, thanks for looking into this.
We will hold off on taking that action, because the change would be BC-breaking.
We maintain a global map from Python classes to their compiled script classes. So I think what's going on is that, in addition to the function and the script class being in separate compilation units, we are not populating that global map for script classes when we deserialize them. Even if we did, we don't serialize enough information to recover the original Python class. So yeah, I think changing the implementation is the best idea.
@mthrok I think using either a dictionary or a named tuple would both be OK.
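As a rough illustration of the named-tuple alternative suggested above: NamedTuple types are natively supported by TorchScript, so a value like this round-trips through serialization without a script-class lookup. The class and field names here are illustrative stand-ins, not torchaudio's actual AudioMetaData definition.

```python
from typing import NamedTuple

import torch


# Hypothetical sketch of the named-tuple alternative; the field names
# are illustrative, not torchaudio's actual AudioMetaData.
class AudioMetaData(NamedTuple):
    sample_rate: int
    num_frames: int
    num_channels: int


@torch.jit.script
def make_meta() -> AudioMetaData:
    # NamedTuple construction and field access are supported inside
    # TorchScript, unlike the experimental script-class machinery.
    return AudioMetaData(44100, 1000, 2)


meta = make_meta()
print(meta.sample_rate)  # 44100
```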
Thanks for the details. I understand that changing the implementation will make it work.
No, but it's probably not a use case that was foreseen based on the design. TorchScript classes are also still considered experimental.
I'm working on a redesign of class types that will most likely get rid of this feature's dependency on that machinery.
This PR adds a CI test to check the backward compatibility of Torchscript function/object dumps. The job first dumps Torchscript objects into files from a torchaudio 0.6.0 release environment (Python 3.6, 3.7, 3.8), then loads & runs the functions in a master build (Python 3.6, 3.7, 3.8). If there is a BC-breaking change in the master build (e.g. a registration schema change), the test should fail. At the moment, the `info` function does not work due to a suspected bug on the torch side, so the test is disabled for `info`. See pytorch/pytorch#42258 for details. Once pytorch/pytorch#42258 is resolved we can enable it.
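The dump-then-load check described above can be sketched roughly as follows. The names (add_one, dump, load_and_run) are illustrative, not taken from the actual torchaudio CI job; in the real job the two halves run in different environments (release vs. master build) rather than one process.

```python
# Hedged sketch of a dump/load backward-compatibility check.
import os
import tempfile

import torch


@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


def dump(path: str) -> None:
    # "Release environment" half: serialize the scripted function.
    torch.jit.save(add_one, path)


def load_and_run(path: str) -> torch.Tensor:
    # "Master build" half: deserialize and execute it. A BC-breaking
    # change in serialization or schema would surface here.
    fn = torch.jit.load(path)
    return fn(torch.tensor(1.0))


path = os.path.join(tempfile.mkdtemp(), "add_one.pt")
dump(path)
print(load_and_run(path))
```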
Thanks for the clarification. One more question.
If this were to be fixed in the future, would the currently dumped files work, or is that uncertain at the moment?
There actually is a work-around (with limitations) at the moment: in the program that loads these saved modules, you can import the definition of the AudioMetaData class beforehand.
Did you try this? I wonder if there will be a naming conflict in the compilation unit. |
I was wondering the same thing, but there is no name conflict, and the failure is gone after adding @torch.jit.script to the AudioMetaData class. I guess this is because the compilation unit used during import is discarded?
Yeah, I tried and this works:

test_common.py

```python
from typing import Dict, List, Optional

import torch


@torch.jit.script
class TypedDataDict(object):
    def __init__(
        self,
        str_to_tensor: Optional[Dict[str, torch.Tensor]] = None,
        str_to_list_of_str: Optional[Dict[str, List[str]]] = None
    ):
        self.str_to_tensor = str_to_tensor
        self.str_to_list_of_str = str_to_list_of_str
```

test_save.py

```python
import torch

from test_common import TypedDataDict


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor):
        return TypedDataDict(None, None)


m = TestModule()
m_scripted = torch.jit.script(m)
m_scripted.save("TypedDataDict.pt")
```

test_load.py

```python
import torch

from test_common import TypedDataDict

m_scripted = torch.jit.load("TypedDataDict.pt")
typed_data_dict = m_scripted(torch.tensor(1.))
print(typed_data_dict)
```

If I run test_save.py and then test_load.py as separate Python invocations, the load succeeds.
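A save step and a load step in two separate interpreter invocations — the condition under which the original failure reproduces — can be driven from one script. This is a hedged sketch: a simple scripted function stands in for the TypedDataDict module, and the file names (save.py, load.py, double.pt) are illustrative.

```python
# Run a save script and a load script in two fresh Python processes.
import os
import subprocess
import sys
import tempfile
import textwrap

workdir = tempfile.mkdtemp()

scripts = {
    "save.py": textwrap.dedent("""\
        import torch

        @torch.jit.script
        def double(x: torch.Tensor) -> torch.Tensor:
            return x * 2

        torch.jit.save(double, "double.pt")
    """),
    "load.py": textwrap.dedent("""\
        import torch

        fn = torch.jit.load("double.pt")
        print(float(fn(torch.tensor(3.0))))
    """),
}
for name, src in scripts.items():
    with open(os.path.join(workdir, name), "w") as f:
        f.write(src)

# First interpreter: dump. Second interpreter: load and run.
subprocess.run([sys.executable, "save.py"], cwd=workdir, check=True)
out = subprocess.run([sys.executable, "load.py"], cwd=workdir,
                     check=True, capture_output=True, text=True)
print(out.stdout.strip())
```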
🐛 Bug
JIT throws
RuntimeError: classType INTERNAL ASSERT FAILED at "../torch/csrc/jit/python/pybind_utils.h":894, please report a bug to PyTorch.
when loading a certain function from a dumped Torchscript object (this does not happen for all the functions I tested). Interestingly, this does not happen when dumping a scripted function to a file and loading the file in the same Python execution. It only happens when I split the dump and the load into two separate Python invocations.
To Reproduce
Steps to reproduce the behavior:

1. conda install torchaudio=0.6.0 -c pytorch
2. Run the following script for the first time. It generates Torchscript objects in the directory foo.
3. Now comment out generate('foo') and run the script again. It causes the error above.

Strangely, only the info function causes this error, so commenting out info_(temp_file) would not cause the same issue. Here is the Python code, and the C++ code is found here. I do not know if this is related, but one difference between the failing info function and the others is that info uses a Python class (AudioMetaData).

Expected behavior
Torchscript should work for info_(temp_file) too.

Environment
This happens in multiple environments.
Env 1 master build
Env 2 official release
Additional context
I encountered this issue while working on the BC check test pytorch/audio#838
cc @ezyang @gchanan @zou3519 @suo @gmagogsfm