Closed
Labels
module: onnx, oncall: export, oncall: pt2, triaged
Description
🐛 Describe the bug
I'm using the script below to export the CosmosTransformer3DModel to ONNX for the video2world pipeline. However, the export fails with a GuardOnDataDependentSymNode error.
# export script
import torch
from diffusers import CosmosTransformer3DModel

device = "cuda"
dtype = torch.bfloat16

model = CosmosTransformer3DModel.from_pretrained(
    "nvidia/Cosmos-Predict2-2B-Video2World",
    subfolder="transformer",
    use_safetensors=True,
    token=<hf-token>,
    torch_dtype=dtype,
).to(device)

batch_size = 1
image_height = 704
image_width = 1280
latent_height = image_height // 8
latent_width = image_width // 8
latent_frames = 1  # 24
latent_channels = model.config.in_channels - 1
text_maxlen = 512

sample_input = (
    torch.randn(
        batch_size,
        latent_channels,
        latent_frames,
        latent_height,
        latent_width,
        dtype=dtype,
        device=device,
    ),
    torch.tensor([1.0] * batch_size, dtype=dtype, device=device),
    torch.randn(
        batch_size, text_maxlen, model.config["text_embed_dim"], dtype=dtype, device=device
    ),
    {
        "fps": torch.tensor([16] * batch_size, dtype=dtype, device=device),
        "padding_mask": torch.ones(1, 1, image_height, image_width, dtype=dtype, device=device),
        "condition_mask": torch.randn(
            batch_size,
            1,
            latent_frames,
            latent_height,
            latent_width,
            dtype=dtype,
            device=device,
        ),
    },
)

input_names = [
    "hidden_states",
    "timestep",
    "encoder_hidden_states",
    "padding_mask",
    "fps",
    "condition_mask",
]
output_names = ["latent"]

dynamic_shapes = (
    {0: "B", 2: "latent_frames", 3: "latent_H", 4: "latent_W"},
    {0: "B"},
    {0: "B"},
    {0: "B", 2: "H", 3: "W"},
    {0: "B"},
    {0: "B", 2: "latent_frames", 3: "latent_H", 4: "latent_W"},
)

# run inference for sanity check
out = model(sample_input[0], sample_input[1], sample_input[2], **sample_input[3])

torch.onnx.export(
    model,
    sample_input,
    "cosmos_transformer_vid2world/model.onnx",
    export_params=True,
    opset_version=19,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_shapes=dynamic_shapes,
    verbose=False,
    dynamo=True,
)
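One thing that may be worth ruling out before reading the trace: `input_names` in the script lists `padding_mask` before `fps`, while the keyword dict in `sample_input` lists `fps` first. The sketch below is a plain-Python consistency check. It assumes names are matched positionally against the positional inputs followed by the keyword inputs in dict order, which is an assumption about the exporter and not something this report confirms. `find_name_mismatches` is a hypothetical helper, and the mismatch may be unrelated to the error reported here.

```python
def find_name_mismatches(input_names, positional_names, kwarg_names):
    """Return (index, given, expected) for every name that is out of order.

    Assumption: inputs are flattened as positional args first, then
    keyword args in the order they appear in the kwargs dict.
    """
    expected = list(positional_names) + list(kwarg_names)
    return [
        (i, given, want)
        for i, (given, want) in enumerate(zip(input_names, expected))
        if given != want
    ]


# Names copied from the export script above.
input_names = [
    "hidden_states",
    "timestep",
    "encoder_hidden_states",
    "padding_mask",
    "fps",
    "condition_mask",
]
positional_names = ["hidden_states", "timestep", "encoder_hidden_states"]
kwarg_names = ["fps", "padding_mask", "condition_mask"]  # dict order in sample_input

mismatches = find_name_mismatches(input_names, positional_names, kwarg_names)
for index, given, want in mismatches:
    print(f"input {index}: named {given!r}, but the sample input there is {want!r}")
```

If the check reports anything, reordering either `input_names` (and the matching `dynamic_shapes` entries) or the keyword dict so they agree removes one variable from the investigation.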
The error stack trace is below:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 118, in __call__
exported_program = self._capture(model, args, kwargs, dynamic_shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 202, in _capture
return torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 319, in export
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 286, in export
return _export(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2172, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2033, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1975, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1901, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1679, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2295, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2233, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2204, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1792, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1279, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1583, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1885, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_cosmos.py", line 566, in forward
hidden_states = block(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_cosmos.py", line 275, in forward
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/attention_processor.py", line 605, in forward
return self.processor(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_cosmos.py", line 206, in __call__
hidden_states = F.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1327, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1374, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 976, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 925, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1429, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 922, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 806, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool
r = self.evaluate()
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7191, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7291, in evaluate_expr
return self._inner_evaluate_expr(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7314, in _inner_evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7538, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(128, u0) (unhinted: Eq(128, u0)). (Size-like symbols: u0)
Caused by: (_ops.py:806 in __call__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_cosmos.py", line 206, in __call__
hidden_states = F.scaled_dot_product_attention(
To fix the error, insert one of the following checks before this call:
1. torch._check(128 == key.shape[3])
2. torch._check(128 != key.shape[3])
(These suggested fixes were derived by replacing `u0` with key.shape[3] in Eq(128, u0) and its negation.)
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test_transformer_export_vid2world.py", line 83, in <module>
torch.onnx.export(
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/__init__.py", line 367, in export
return _compat.export_compat(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/_internal/exporter/_compat.py", line 119, in export_compat
onnx_program = _core.export(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/_internal/exporter/_flags.py", line 20, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/onnx/_internal/exporter/_core.py", line 1367, in export
raise _errors.TorchExportError(
torch.onnx._internal.exporter._errors.TorchExportError: Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and summit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.
## Exception summary
<class 'torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode'>: Could not guard on data-dependent expression Eq(128, u0) (unhinted: Eq(128, u0)). (Size-like symbols: u0)
Caused by: (_ops.py:806 in __call__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/transformers/transformer_cosmos.py", line 206, in __call__
hidden_states = F.scaled_dot_product_attention(
To fix the error, insert one of the following checks before this call:
1. torch._check(128 == key.shape[3])
2. torch._check(128 != key.shape[3])
(These suggested fixes were derived by replacing `u0` with key.shape[3] in Eq(128, u0) and its negation.)
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
(Refer to the full stack trace above for more information.)
Versions
torch==2.8
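The error message also lists environment variables that produce extra diagnostics. A sketch of rerunning the reproduction script with them set, using only the standard library, is shown below. `debug_env` and `rerun_with_debug_logs` are hypothetical helper names, and the script name in the usage comment is the one that appears in the traceback.

```python
import os
import subprocess
import sys


def debug_env(symbol="u0", cpp=False, base=None):
    """Build an environment with the debug variables named in the error message."""
    env = dict(os.environ if base is None else base)
    env["TORCH_LOGS"] = "dynamic"
    env["TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL"] = symbol
    if cpp:
        # Only needed if the guard is suspected to come from C++.
        env["TORCHDYNAMO_EXTENDED_DEBUG_CPP"] = "1"
    return env


def rerun_with_debug_logs(script, symbol="u0", cpp=False):
    """Run `script` in a child Python process with the debug variables set."""
    return subprocess.run(
        [sys.executable, script],
        env=debug_env(symbol=symbol, cpp=cpp),
        check=False,  # the export is expected to fail; the logs are what matter
    )


# Usage (not executed here):
# rerun_with_debug_logs("test_transformer_export_vid2world.py")
```

The extended log for `u0` should show where the unbacked symbol is created, which helps decide where a check or a model change belongs.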
cc @justinchuby @titaiwangms @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4