Skip to content

[onnx] crash when exporting a model with Sequence module (encodeBlock: Assertion failled ) #19227

@gkossakowski

Description

@gkossakowski

🐛 Bug

The tracer in onnx exporter is failing with an assertion error in the exporter.

To Reproduce

Steps to reproduce the behavior:

  1. define a model:
class SomeModel(nn.Module):

    def __init__(self):
        super(SomeModel, self).__init__()
        dim = 5
        self.emb = nn.Embedding(10, dim)
        self.lin1 = nn.Linear(dim, 1)
        self.seq = nn.Sequential(
            self.emb,
            self.lin1,
        )

    def forward(self, input):
        return self.seq(input)

model = SomeModel()
  1. run export:
dummy_input = torch.tensor([2], dtype=torch.long)
dummy_input

torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)

fails with:

[redacted]/vendor/local/lib/python2.7/site-packages/torch/onnx/utils.pyc in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate)
    230     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
    231     if export_params:
--> 232         proto, export_map = graph.export(params, _onnx_opset_version, defer_weight_export, operator_export_type)
    233     else:
    234         proto, export_map = graph.export([], _onnx_opset_version, False, operator_export_type)

RuntimeError: torch/csrc/jit/export.cpp:296: encodeBlock: Assertion `b->inputs().size() >= num_initializers` failed.

Expected behavior

Model is exported to ONNX

Environment

Collecting environment information...
PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect

Python version: 2.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] msgpack-numpy==0.4.3.1
[pip] numpy==1.14.1
[pip] torch==0.4.1
[conda] Could not collect

Additional context

As a workaround one can inline layer definitions into the sequence:

class SomeModel(nn.Module):

    def __init__(self):
        super(SomeModel, self).__init__()
        dim = 5
        self.seq = nn.Sequential(
            nn.Embedding(10, dim),
            nn.Linear(dim, 1),
        )

    def forward(self, input):
        return self.seq(input)

model = SomeModel()

Metadata

Metadata

Assignees

Labels

high prioritymodule: assert failureThe issue involves an assert failuremodule: onnxRelated to torch.onnxtriagedThis issue has been looked at a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions