Closed
Labels
high priority · module: assert failure · module: onnx · triaged
Description
🐛 Bug
The tracer in the ONNX exporter fails with an assertion error during export.
To Reproduce
Steps to reproduce the behavior:
- Define a model:

import torch
import torch.nn as nn

class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        dim = 5
        self.emb = nn.Embedding(10, dim)
        self.lin1 = nn.Linear(dim, 1)
        self.seq = nn.Sequential(
            self.emb,
            self.lin1,
        )

    def forward(self, input):
        return self.seq(input)

model = SomeModel()
- Run the export:

dummy_input = torch.tensor([2], dtype=torch.long)
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)

This fails with:
[redacted]/vendor/local/lib/python2.7/site-packages/torch/onnx/utils.pyc in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate)
230 defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
231 if export_params:
--> 232 proto, export_map = graph.export(params, _onnx_opset_version, defer_weight_export, operator_export_type)
233 else:
234 proto, export_map = graph.export([], _onnx_opset_version, False, operator_export_type)
RuntimeError: torch/csrc/jit/export.cpp:296: encodeBlock: Assertion `b->inputs().size() >= num_initializers` failed.
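A guess at what may be going on (not verified against the exporter internals): because self.emb and self.lin1 are registered both directly on the module and inside self.seq, each parameter appears under two names in state_dict() while the underlying tensors are shared, so the exporter may end up with more initializers than traced graph inputs. A minimal sketch showing the duplication:

# Each shared submodule contributes its parameters twice to state_dict()
# (e.g. once as "emb.weight" and once as "seq.0.weight"), even though both
# names refer to the same underlying tensor. This mismatch may be what trips
# the initializer count check in export.cpp, but that is only a guess.
model = SomeModel()
state = model.state_dict()
print(sorted(state.keys()))
print(len(state), len({t.data_ptr() for t in state.values()}))  # 6 names, 3 unique tensors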
Expected behavior
The model is exported to ONNX successfully.
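For reference, a minimal sketch of how a successful export could be verified (assuming the standalone onnx Python package is installed):

import onnx

# Load the exported file and run the structural checker; check_model raises
# if the graph is malformed.
proto = onnx.load("model.onnx")
onnx.checker.check_model(proto)
print(onnx.helper.printable_graph(proto.graph))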
Environment
Collecting environment information...
PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect
Python version: 2.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] msgpack-numpy==0.4.3.1
[pip] numpy==1.14.1
[pip] torch==0.4.1
[conda] Could not collect
Additional context
As a workaround, one can inline the layer definitions into the nn.Sequential instead of also keeping them as attributes on the module:

class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        dim = 5
        self.seq = nn.Sequential(
            nn.Embedding(10, dim),
            nn.Linear(dim, 1),
        )

    def forward(self, input):
        return self.seq(input)

model = SomeModel()
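With this definition, re-running the export call from the repro above should go through (same invocation, repeated here for completeness):

dummy_input = torch.tensor([2], dtype=torch.long)
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)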