
Dynamo export: Qwen3 Omni model running into data-dependent expression errors #169009

@yuanyao-nv


🐛 Describe the bug

I have the following toy model, excerpted from the Qwen3OmniMoeAudioEncoder model, and I'm hitting two data-dependent expression errors in the FX graph generation step.

import torch
import torch.nn as nn
import torch.nn.functional as F


def _get_feat_extract_output_lengths(input_lengths):
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    return output_lengths


class AudioChunkProcessor(nn.Module):
    def __init__(
        self,
        num_mel_bins: int = 128,
        downsample_hidden_size: int = 512,
        d_model: int = 1280,
        n_window: int = 1500,
        conv_chunksize: int = 64,
    ):
        super().__init__()
        self.num_mel_bins = num_mel_bins
        self.n_window = n_window
        self.conv_chunksize = conv_chunksize

        self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)

        self.conv_out = nn.Linear(
            downsample_hidden_size * ((((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            d_model,
            bias=False,
        )

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.LongTensor,
        aftercnn_lens: torch.LongTensor = None,
    ) -> torch.Tensor:
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)

        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)

        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)

        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )

        padded_feature = padded_feature.unsqueeze(1)

        padded_embeds = []
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)

        return padded_embed

if __name__ == "__main__":
    model = AudioChunkProcessor(
        num_mel_bins=128,
        downsample_hidden_size=480,
        d_model=1280,
        n_window=50,
        conv_chunksize=500,
    )

    input_features = torch.randn(128, 290)  # (num_mel_bins, total_time)
    feature_lens = torch.tensor([290])  # lengths of each audio

    output = model(input_features, feature_lens)
    torch.export.export(model, (input_features, feature_lens))

Error 1:

  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/workspace/Qwen3_omni/export_toy_audio.py", line 48, in forward
    [self.n_window * 2] * chunk_num.sum(),
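
The guard here comes from using the data-dependent value chunk_num.sum() to size a Python list. A minimal sketch of one possible rewrite for just this construction (untested; the >= 1 lower bound is an assumption added so downstream guards can resolve, and the tail-chunk indexing after it would still need its own handling):

total_chunks = chunk_num.sum().item()   # becomes an unbacked SymInt during export
torch._check_is_size(total_chunks)      # mark it as a non-negative, size-like value
torch._check(total_chunks >= 1)         # assumed lower bound
chunk_lengths = torch.full(
    (total_chunks,),
    self.n_window * 2,
    dtype=torch.long,
    device=feature_lens.device,
)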

Error 2:
If I change [self.n_window * 2] * chunk_num.sum() to a fixed count such as [self.n_window * 2] * 3, then I hit:

  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Max(u0, u1, u2) + 2 < 3 (unhinted: Max(u0, u1, u2) + 2 < 3).  (Size-like symbols: u0, u1, u2)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true

Caused by: (_subclasses/fake_impls.py:1028 in conv)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0,u1,u2"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/workspace/Qwen3_omni/export_toy_audio.py", line 71, in forward
    padded_embed = F.gelu(self.conv2d1(chunk))

To fix the error, insert one of the following checks before this call:
  1. torch._check(2 + max(chunk_list[0].shape[0], chunk_list[1].shape[0], chunk_list[2].shape[0]) < 3)
  2. torch._check(2 + max(chunk_list[0].shape[0], chunk_list[1].shape[0], chunk_list[2].shape[0]) >= 3)

(These suggested fixes were derived by replacing `u0` with chunk_list[0].shape[0], `u1` with chunk_list[1].shape[0], `u2` with chunk_list[2].shape[0] in Max(u0, u1, u2) + 2 < 3 and its negation.)
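
Of the two, the second check is the one consistent with how chunk_lengths is constructed above (every chunk ends up with at least one frame). A sketch of where it could go in forward, right after chunk_list is built (the per-chunk >= 1 bound is an assumption; it implies the Max(u0, u1, u2) + 2 >= 3 form the error asks for):

chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
for chunk in chunk_list:
    # Each chunk should have at least one time step, which lets the
    # conv kernel-size guard resolve statically.
    torch._check(chunk.shape[0] >= 1)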

Versions

Relevant packages:
torch 2.9.1
onnxscript 0.5.6

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
