Description
Describe the issue
**Runtime error before training starts.**
```
Traceback (most recent call last):
  File "/workspace/optimum/./examples/onnxruntime/training/language-modeling/run_clm.py", line 671, in <module>
    main()
  File "/workspace/optimum/./examples/onnxruntime/training/language-modeling/run_clm.py", line 618, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/workspace/optimum/optimum/onnxruntime/trainer.py", line 408, in train
    return inner_training_loop(
  File "/workspace/optimum/optimum/onnxruntime/trainer.py", line 734, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
  File "/workspace/optimum/optimum/onnxruntime/trainer.py", line 301, in compute_loss
    return super().compute_loss(model_with_loss, inputs, return_outputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3532, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_utils.py", line 388, in _forward
    return ortmodule._torch_module.forward(*inputs, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_utils.py", line 368, in _forward
    return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 326, in forward
    self._fallback_manager.handle_exception(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 268, in forward
    self._build_graph(graph_transformer_config)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_logger.py", line 161, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 341, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 182, in _build_graph
    self.graph_builder.build(config)
RuntimeError: /workspace/onnxruntime/orttraining/orttraining/core/graph/gradient_builder_base.h:123 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t, bool) const i < node->OutputDefs().size() was false.
```
The assertion fails while building the gradient for the attention Dropout node whose output is `/_original_module/transformer/h.0/attn/Dropout_output_0`: ORTModule expects this node to have 2 outputs, but it has only one.
The issue does not occur when attention dropout is disabled, i.e. when `config.attn_pdrop = 0` is set at https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/language-modeling/run_clm.py#L435
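A minimal sketch of that workaround, assuming the `transformers` package is installed. In run_clm.py the config is loaded with `AutoConfig.from_pretrained(...)`; here a default `GPT2Config` stands in so the snippet runs offline, and `disable_attention_dropout` is a hypothetical helper, not part of the script.

```python
from transformers import GPT2Config


def disable_attention_dropout(config: GPT2Config) -> GPT2Config:
    """Workaround from this report: turn off attention dropout only.

    Residual and embedding dropout (resid_pdrop, embd_pdrop) are left unchanged.
    """
    config.attn_pdrop = 0.0
    return config


# Stand-in for the config run_clm.py builds before creating the model.
config = disable_attention_dropout(GPT2Config())
print(config.attn_pdrop)  # 0.0
```

This only sidesteps the failure: the model then trains without attention dropout, which changes the training setup rather than fixing the gradient builder.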
To reproduce
Steps to reproduce the issue:
Clone https://github.com/huggingface/optimum, then run the following from the repository root:

```bash
python ./examples/onnxruntime/training/language-modeling/run_clm.py \
  --model_name_or_path gpt2 \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --label_smoothing 0.1 \
  --max_steps 150 \
  --logging_steps 1 \
  --logging_dir log \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --output_dir output \
  --overwrite_output_dir \
  --skip_memory_metrics \
  --fp16 \
  --do_train \
  --do_eval
```
Urgency
No response
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
PyTorch Version
2.3
Execution Provider
ROCm
Execution Provider Library Version
ROCm 6.2