Skip to content

Commit f0d3c33

Browse files
authored
[QNN EP] Add Einsum support for some equations (#24616)
[QNN EP] Add Einsum support for some equations. Intend is not to support all equations. But to enable case by case to improve performance.
1 parent 1c8130e commit f0d3c33

File tree

8 files changed

+754
-9
lines changed

8 files changed

+754
-9
lines changed

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
174174
CreateExpandOpBuilder("Expand", *this);
175175
}
176176

177+
{
178+
CreateEinsumOpBuilder("Einsum", *this);
179+
}
180+
177181
{
178182
CreateMatMulOpBuilder("MatMul", *this);
179183
}

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
100100
void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
101101

102102
void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
103+
104+
void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
103105
} // namespace qnn
104106
} // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
9494
if (node_unit.Inputs().size() > 1) {
9595
const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
9696
if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
97-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
97+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
9898
}
9999
}
100100
if (node_unit.Inputs().size() > 2) {
101101
const auto& max_input_name = node_unit.Inputs()[2].node_arg.Name();
102102
if (!max_input_name.empty() && !qnn_model_wrapper.IsConstantInput(max_input_name)) {
103-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
103+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
104104
}
105105
}
106106
return Status::OK();

onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc

Lines changed: 396 additions & 0 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const
4646
for (size_t i = 1; i < input_count; i++) {
4747
const auto& next_input = node_unit.Inputs()[i].node_arg.Name();
4848
if (!qnn_model_wrapper.IsConstantInput(next_input)) {
49-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic slice.");
49+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic slice.");
5050
}
5151
}
5252
}

onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
4242
std::vector<std::string>& input_names,
4343
bool do_op_validation) const {
4444
const auto& inputs = node_unit.Inputs();
45-
// QNN Tile only support 1 input, the 2nd input need to be initialier and set as Qnn node parameter
45+
// QNN Tile only support 1 input, the 2nd input need to be initializer and set as Qnn node parameter
4646
if (do_op_validation) {
4747
auto& repeats_input_name = inputs[1].node_arg.Name();
4848
ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(repeats_input_name),
@@ -60,7 +60,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
6060
const logging::Logger& logger,
6161
bool do_op_validation) const {
6262
std::vector<std::string> param_tensor_names;
63-
// Already confirmed repeats input is initailizer in ProcessInputs()
63+
// Already confirmed repeats input is initializer in ProcessInputs()
6464
const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name();
6565

6666
std::vector<uint8_t> unpacked_tensor;

onnxruntime/core/providers/qnn/builder/qnn_model.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,16 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
180180
auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());
181181

182182
if (Status::OK() != result) {
183-
LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
184-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
183+
const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name();
184+
LOGS(logger, ERROR) << message;
185+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
185186
}
186187

187188
result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
188189
if (Status::OK() != result) {
189-
LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
190-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
190+
const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name();
191+
LOGS(logger, ERROR) << message;
192+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
191193
}
192194

193195
return Status::OK();

0 commit comments

Comments
 (0)