-
Notifications
You must be signed in to change notification settings - Fork 26.2k
Open
Description
🐛 Describe the bug
I have the following toy model excerpted from the Qwen3OmniMoeAudioEncoder model and I'm hitting two data-dependent expression errors in the FX graph generation step.
import torch
import torch.nn as nn
import torch.nn.functional as F
def _get_feat_extract_output_lengths(input_lengths):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
return output_lengths
class AudioChunkProcessor(nn.Module):
def __init__(
self,
num_mel_bins: int = 128,
downsample_hidden_size: int = 512,
d_model: int = 1280,
n_window: int = 1500,
conv_chunksize: int = 64,
):
super().__init__()
self.num_mel_bins = num_mel_bins
self.n_window = n_window
self.conv_chunksize = conv_chunksize
self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1)
self.conv2d2 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)
self.conv2d3 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)
self.conv_out = nn.Linear(
downsample_hidden_size * ((((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
d_model,
bias=False,
)
def forward(
self,
input_features: torch.Tensor,
feature_lens: torch.LongTensor,
aftercnn_lens: torch.LongTensor = None,
) -> torch.Tensor:
aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
chunk_lengths = torch.tensor(
[self.n_window * 2] * chunk_num.sum(),
dtype=torch.long,
device=feature_lens.device,
)
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
[torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
batch_first=True,
)
padded_feature = padded_feature.unsqueeze(1)
padded_embeds = []
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
padded_embed = F.gelu(self.conv2d1(chunk))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
padded_embeds.append(padded_embed)
padded_embed = torch.cat(padded_embeds, dim=0)
return padded_embed
if __name__ == "__main__":
model = AudioChunkProcessor(
num_mel_bins=128,
downsample_hidden_size=480,
d_model=1280,
n_window=50,
conv_chunksize=500,
)
input_features = torch.randn(128, 290) # (num_mel_bins, total_time)
feature_lens = torch.tensor([290]) # lengths of each audio
output = model(input_features, feature_lens)
torch.export.export(model, (input_features, feature_lens))
Error 1:
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)
Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/workspace/Qwen3_omni/export_toy_audio.py", line 48, in forward
[self.n_window * 2] * chunk_num.sum(),
Error 2:
If I change [self.n_window * 2] * chunk_num.sum() to a fixed number such as [self.n_window * 2] * 3, then I hit
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Max(u0, u1, u2) + 2 < 3 (unhinted: Max(u0, u1, u2) + 2 < 3). (Size-like symbols: u0, u1, u2)
consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_trueCaused by: (_subclasses/fake_impls.py:1028 in conv)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0,u1,u2"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/workspace/Qwen3_omni/export_toy_audio.py", line 71, in forward
padded_embed = F.gelu(self.conv2d1(chunk))
To fix the error, insert one of the following checks before this call:
1. torch._check(2 + max(chunk_list[0].shape[0], chunk_list[1].shape[0], chunk_list[2].shape[0]) < 3)
2. torch._check(2 + max(chunk_list[0].shape[0], chunk_list[1].shape[0], chunk_list[2].shape[0]) >= 3)
(These suggested fixes were derived by replacing `u0` with chunk_list[0].shape[0], `u1` with chunk_list[1].shape[0], `u2` with chunk_list[2].shape[0] in Max(u0, u1, u2) + 2 < 3 and its negation.)
Versions
Relevant packages:
torch 2.9.1
onnxscript 0.5.6
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4