Commit a470da1

[CoreML] Add support for int64 (#24462)
### Description

Add int64 as a supported data type for moving nodes to the CoreML EP. We already convert constants from int64 to int32 for CoreML automatically by calling narrow; this change adds the same conversion for outputs.

### Motivation and Context

- More nodes supported on CoreML

### Note on the Unsqueeze op

According to #22975 there is a bug with the Unsqueeze op with scalar inputs on x86. I was running into a bug for unsqueezes that expand a scalar input to a tensor of shape [1], since CoreML doesn't support scalar values for MLProgram. I adapted the HandleX86ArchUnsqueeze method, though alternatively we could replace the op with an identity operator or add some additional checks. I went with adapting HandleX86ArchUnsqueeze since it seemed like the fastest solution.
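For context, here is a minimal standalone sketch of the checked narrowing idea (illustrative names only, not the EP's actual helper): each int64 constant value is range-checked and copied into an int32 buffer, so out-of-range values fail loudly instead of silently wrapping.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical helper: narrow int64 tensor data to int32, throwing if a value
// does not fit. The CoreML EP performs an equivalent checked narrowing when
// converting int64 constants for use in an MLProgram.
std::vector<int32_t> NarrowToInt32(const std::vector<int64_t>& src) {
  std::vector<int32_t> dst;
  dst.reserve(src.size());
  for (int64_t v : src) {
    if (v < std::numeric_limits<int32_t>::min() ||
        v > std::numeric_limits<int32_t>::max()) {
      throw std::out_of_range("int64 value does not fit in int32");
    }
    dst.push_back(static_cast<int32_t>(v));
  }
  return dst;
}
```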
1 parent 838b97e commit a470da1

12 files changed (+55, -68 lines)

onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc

Lines changed: 1 addition & 3 deletions

```diff
@@ -41,9 +41,7 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));
     AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims)));
 
-    int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
-    // the output of ArgMax must be int32
-    AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
+    AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
   } else {
     auto* coreml_argmax = layer->mutable_argmax();
```

onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc

Lines changed: 3 additions & 2 deletions

```diff
@@ -115,8 +115,9 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx,
   }
 
 #if CAN_BUILD_COREML6_OR_LATER
-  // only MLProgram support FP16
-  if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+  // only MLProgram support FP16 and INT64
+  if (input_params.create_mlprogram && (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
+                                        input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64)) {
     return true;
   }
 #endif
```

onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc

Lines changed: 18 additions & 3 deletions

```diff
@@ -54,6 +54,17 @@ bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger
                     y_shape_proto->dim().begin(), y_shape_proto->dim().end(),
                     dim_eq);
 }
+
+bool ShouldUseFloorDiv(const Node& node, const logging::Logger& logger) {
+  // since ONNX spec requires both inputs to have the same type, we only need
+  // to check the first input type
+  const auto& input0 = *node.InputDefs()[0];
+  int32_t input_type0 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  GetType(input0, input_type0, logger);
+
+  return input_type0 == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
+         input_type0 == ONNX_NAMESPACE::TensorProto_DataType_INT64;
+}
 }  // namespace
 
 static std::vector<int64_t> InferOutputShape(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
@@ -131,9 +142,13 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   } else if (op_type == "Sub") {
     coreml_op_type = "sub";
   } else if (op_type == "Div") {
-    // we support fp32/fp16 currently. when we add support for integers we need to check the type and use
-    // "floor_div" or "real_div" accordingly
-    coreml_op_type = "real_div";
+    // Use "floor_div" op for integer division (int32 or int64)
+    // use "real_div" for float division (fp16 or fp32)
+    if (ShouldUseFloorDiv(node, logger)) {
+      coreml_op_type = "floor_div";
+    } else {
+      coreml_op_type = "real_div";
+    }
   } else if (op_type == "Pow") {
     coreml_op_type = "pow";
   } else {
```
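The Div dispatch above is purely type-driven, and since the ONNX spec requires both Div inputs to share a type, checking the first input is sufficient. A minimal standalone sketch of the same decision, using a hypothetical enum in place of the ONNX type constants:

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in for ONNX_NAMESPACE::TensorProto_DataType values.
enum class ElemType : int32_t { kFloat16, kFloat, kInt32, kInt64 };

// Integer inputs route to CoreML's "floor_div"; floating-point inputs
// route to "real_div", mirroring the new ShouldUseFloorDiv check.
std::string DivOpNameFor(ElemType t) {
  return (t == ElemType::kInt32 || t == ElemType::kInt64) ? "floor_div"
                                                          : "real_div";
}
```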

onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc

Lines changed: 5 additions & 8 deletions

```diff
@@ -261,9 +261,10 @@ MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_INT16:
       return MILSpec::DataType::INT16;
     case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-      return MILSpec::DataType::INT32;
     case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-      return MILSpec::DataType::INT64;
+      // CoreML only supports int32 for its operations and can only produce int32 values so
+      // we convert any int64 to int32.
+      return MILSpec::DataType::INT32;
 
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
       return MILSpec::DataType::UINT8;
@@ -367,19 +368,15 @@ void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::st
   SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(element_type), shape, /*convert_scalar*/ true);
 }
 
-void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output,
-                        std::optional<int32_t> override_element_type) {
+void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) {
   auto& outputs = *op.mutable_outputs();
   auto& output_arg = *outputs.Add();
   output_arg.set_name(output.Name());
 
   MILSpec::ValueType& value = *output_arg.mutable_type();
   MILSpec::TensorType& tensor_type = *value.mutable_tensortype();
 
-  auto elem_type = override_element_type ? *override_element_type
-                                         : output.TypeAsProto()->tensor_type().elem_type();
-
-  SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(elem_type), output.Shape(), /*convert_scalar*/ true);
+  SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), output.Shape(), /*convert_scalar*/ true);
 }
 
 void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
```
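The refactor above relies on a single invariant: int64 is collapsed to int32 at the one point where tensor type info is written, so callers no longer override output types case by case. A standalone sketch of that invariant, with hypothetical enums standing in for the ONNX and MIL type enums:

```cpp
#include <cstdint>

// Hypothetical stand-ins for ONNX_NAMESPACE::TensorProto_DataType and
// COREML_SPEC::MILSpec::DataType.
enum class OnnxType : int32_t { kFloat, kInt32, kInt64 };
enum class MilType : int32_t { kFloat32, kInt32 };

// Mirror of the new OnnxDataTypeToMILSpec behavior: CoreML ops only
// consume and produce int32, so int64 never reaches the MLProgram.
MilType ToMilType(OnnxType t) {
  switch (t) {
    case OnnxType::kFloat:
      return MilType::kFloat32;
    case OnnxType::kInt32:
    case OnnxType::kInt64:  // fall through: int64 is represented as int32
      return MilType::kInt32;
  }
  return MilType::kInt32;  // unreachable for well-formed input
}
```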

onnxruntime/core/providers/coreml/builders/impl/builder_utils.h

Lines changed: 2 additions & 6 deletions

```diff
@@ -98,6 +98,7 @@ COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() {
 
 // The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value.
 // Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally
+// This method also automatically converts int64 to int32 since only int32 is supported for CoreML operations.
 COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type);
 
 /// <summary>
@@ -156,12 +157,7 @@ void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::st
 /// </summary>
 /// <param name="op">Operation to update.</param>
 /// <param name="output">NodeArg with details of output to add.</param>
-/// <param name="override_element_type">
-/// Override the element type. Only set to handle cases where we believe the data at runtime will be int32 but
-/// the original ONNX node has type int64.
-/// </param>
-void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output,
-                        std::optional<int32_t> override_element_type = std::nullopt);
+void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output);
 
 /// <summary>
 /// Add pad_type and pad values.
```

onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc

Lines changed: 1 addition & 2 deletions

```diff
@@ -44,7 +44,6 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model
     // CoreML operators can only produce int32 and not int64 values.
     // Due to that there should be no actual int64 values inside the CoreML model and we can infer any
     // ONNX_NAMESPACE::TensorProto::INT64 values to be int32.
-    cast_to_type = ONNX_NAMESPACE::TensorProto::INT32;
   } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) {
     to_dtype = "fp32";
   } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) {
@@ -69,7 +68,7 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model
   if (op_type == "cast") {
     AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype)));
   }
-  AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type);
+  AddOperationOutput(*op, *node.OutputDefs()[0]);
   model_builder.AddOperation(std::move(op));
 }
 
```

onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc

Lines changed: 1 addition & 10 deletions

```diff
@@ -35,23 +35,14 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
     using CoreML::Specification::MILSpec::Operation;
     std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "gather");
 
-    std::optional<int32_t> output_datatype;
-
-    int32_t input_type;
-    ORT_RETURN_IF_NOT(GetType(*node.InputDefs()[0], input_type, logger), "Failed to get input type");
-
-    if (input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
-      output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
-    }
-
     const auto axis = GetAxisAttribute(node);
     // coreml docs claims validate_indices is optional but in practice it is required
     const auto validate_indices = false;
     AddOperationInput(*op, "x", node.InputDefs()[0]->Name());        // data
     AddOperationInput(*op, "indices", node.InputDefs()[1]->Name());  // indices
     AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));  // axis attr
     AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices));
-    AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);  // output
+    AddOperationOutput(*op, *node.OutputDefs()[0]);  // output
     model_builder.AddOperation(std::move(op));
   } else {
     auto layer = model_builder.CreateNNLayer(node);
```

onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc

Lines changed: 8 additions & 0 deletions

```diff
@@ -150,6 +150,14 @@ bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParam
       LOGS(logger, VERBOSE) << "constant_value must be a constant initializer.";
       return false;
     }
+
+    int32_t constant_value_type;
+    GetType(*input_defs[2], constant_value_type, logger);
+
+    if (constant_value_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      LOGS(logger, VERBOSE) << "Only float constant_value is supported, got type: " << constant_value_type;
+      return false;
+    }
   }
 
   {
```

onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc

Lines changed: 4 additions & 3 deletions

```diff
@@ -56,10 +56,10 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
       std::vector<int64_t> sizes = {size};
       AddOperationInput(*slice_op, "begin", model_builder.AddConstant(slice_op->type(), "begin", starts));
       AddOperationInput(*slice_op, "size", model_builder.AddConstant(slice_op->type(), "size", sizes));
-      AddOperationOutput(*slice_op, *node.OutputDefs()[0], output_datatype);
+      AddOperationOutput(*slice_op, *node.OutputDefs()[0]);
       model_builder.AddOperation(std::move(slice_op));
     } else {
-      AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
+      AddOperationOutput(*op, *node.OutputDefs()[0]);
       model_builder.AddOperation(std::move(op));
     }
   } else {
@@ -127,7 +127,8 @@ bool ShapeOpBuilder::HasSupportedInputsImpl(const Node& node,
   if (input_params.create_mlprogram) {
     if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
          input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
-         input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
+         input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
+         input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64)) {
       return true;
     } else {
       LOGS(logger, VERBOSE) << "[" << node.OpType()
```

onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc

Lines changed: 1 addition & 16 deletions

```diff
@@ -143,21 +143,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
     }
   }
 
-  // Int32, float and float16 are supported by CoreML slice_by_index.
-  // We convert any int64 model input to int32 when running the CoreML model for the partition.
-  // Any other integer data created at runtime is the output from CoreML operations, and should int32 not int64.
-  // Based on that, we assume that the actual input when running will be int32, so we override the output data
-  // type to reflect this.
-  // If we were to leave it as TensorProto_DataType_INT64 the CoreML model would be invalid.
-  std::optional<int32_t> output_datatype;
-
-  int32_t input_type;
-  ORT_RETURN_IF_NOT(GetType(*node.InputDefs()[0], input_type, logger), "Failed to get input type");
-
-  if (input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
-    output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
-  }
-
   auto op = model_builder.CreateOperation(node, "slice_by_index");
 
   auto begin = model_builder.AddConstant(op->type(), "begin", AsSpan(compute_metadata.starts_));
@@ -173,7 +158,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   AddOperationInput(*op, "begin_mask", begin_mask);
   AddOperationInput(*op, "end_mask", end_mask);
 
-  AddOperationOutput(*op, *output_defs[0], output_datatype);
+  AddOperationOutput(*op, *output_defs[0]);
 
   model_builder.AddOperation(std::move(op));
```