Skip to content

Commit

Permalink
Add bfloat16 to 47 ops
Browse files Browse the repository at this point in the history
Conversion to/from 21<->22 also implemented

Fixes onnx#3842

Signed-off-by: Thiago Crepaldi <thiagofc@microsoft.com>
  • Loading branch information
thiagocrepaldi committed Apr 17, 2024
1 parent c459890 commit 1f081eb
Show file tree
Hide file tree
Showing 255 changed files with 7,281 additions and 865 deletions.
2,557 changes: 2,403 additions & 154 deletions docs/Changelog.md

Large diffs are not rendered by default.

380 changes: 220 additions & 160 deletions docs/Operators.md

Large diffs are not rendered by default.

Binary file modified onnx/backend/test/data/node/test_acos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_double/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_seed/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_1d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_3d/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pad/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_2d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_nd/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default_mask/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalaveragepool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bilinear/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_nearest/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_lower/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_upper/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_dilations/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_weight/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_roialign_mode_max/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_round/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_thresholdedrelu/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_training_dropout/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
9 changes: 2 additions & 7 deletions onnx/backend/test/stat_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,8 @@ def gen_node_test_coverage(
f.write("Node tests have covered 0/0 (N/A) common operators. \n\n")
if num_experimental:
f.write(
"Node tests have covered {}/{} ({:.2f}%, {} generators excluded) " # noqa: UP032
"experimental operators.\n\n".format(
len(experimental_covered),
num_experimental,
(len(experimental_covered) / float(num_experimental) * 100),
len(experimental_generator),
)
f"Node tests have covered {len(experimental_covered)}/{num_experimental} ({len(experimental_covered) / float(num_experimental) * 100:.2f}%, {len(experimental_generator)} generators excluded) "
"experimental operators.\n\n"
)
else:
f.write("Node tests have covered 0/0 (N/A) experimental operators.\n\n")
Expand Down
8 changes: 1 addition & 7 deletions onnx/defs/gen_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,7 @@ def main(args: Args) -> None:
if function_ops:
fout.write("|**Function**|**Since version**|**Function version**|\n")
for n, schema, versions, function_versions in function_ops:
s = '|{}<a href="#{}">{}</a>|{}|{}|\n'.format( # noqa: UP032
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
format_name_with_domain(domain, n),
format_versions(versions, args.changelog),
format_function_versions(function_versions),
)
s = f'|{support_level_str(schema.support_level)}<a href="#{format_name_with_domain(domain, n)}">{format_name_with_domain(domain, n)}</a>|{format_versions(versions, args.changelog)}|{format_function_versions(function_versions)}|\n'
fout.write(s)

fout.write("\n")
Expand Down
116 changes: 32 additions & 84 deletions onnx/defs/generator/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

static const char* EyeLike_ver9_doc = R"DOC(
static const char* EyeLike_ver22_doc = R"DOC(
Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D
tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the
same as the input tensor. The data type can be specified by the 'dtype' argument. If
Expand All @@ -138,9 +138,9 @@ TensorProto message and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
EyeLike,
9,
22,
OpSchema()
.SetDoc(EyeLike_ver9_doc)
.SetDoc(EyeLike_ver22_doc)
.Attr(
"k",
"(Optional) Index of the diagonal to be populated with ones. Default is 0."
Expand All @@ -159,33 +159,11 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor, same shape as input tensor T1.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain input types. Strings and complex are not supported.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types. Strings and complex are not supported.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr) {
Expand All @@ -202,7 +180,7 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* RandomUniform_ver1_doc = R"DOC(
static const char* RandomUniform_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is specified by the `shape` argument and the range by `low` and `high`.
Expand All @@ -213,9 +191,9 @@ TensorProto message.

ONNX_OPERATOR_SET_SCHEMA(
RandomUniform,
1,
22,
OpSchema()
.SetDoc(RandomUniform_ver1_doc)
.SetDoc(RandomUniform_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -230,16 +208,13 @@ ONNX_OPERATOR_SET_SCHEMA(
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));

static const char* RandomNormal_ver1_doc = R"DOC(
static const char* RandomNormal_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
specified by `mean` and `scale`.
Expand All @@ -251,9 +226,9 @@ TensorProto message.

ONNX_OPERATOR_SET_SCHEMA(
RandomNormal,
1,
22,
OpSchema()
.SetDoc(RandomNormal_ver1_doc)
.SetDoc(RandomNormal_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -268,16 +243,13 @@ ONNX_OPERATOR_SET_SCHEMA(
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));

static const char* RandomUniformLike_ver1_doc = R"DOC(
static const char* RandomUniformLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the uniform distribution are specified by `low` and `high`.
Expand All @@ -289,9 +261,9 @@ TensorProto message and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
RandomUniformLike,
1,
22,
OpSchema()
.SetDoc(RandomUniformLike_ver1_doc)
.SetDoc(RandomUniformLike_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -309,12 +281,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T2", OpSchema::all_float_types_ir10(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
Expand All @@ -326,7 +295,7 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* RandomNormalLike_ver1_doc = R"DOC(
static const char* RandomNormalLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the normal distribution are specified by `mean` and `scale`.
Expand All @@ -338,9 +307,9 @@ TensorProto message, and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
RandomNormalLike,
1,
22,
OpSchema()
.SetDoc(RandomNormalLike_ver1_doc)
.SetDoc(RandomNormalLike_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -358,12 +327,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T2", OpSchema::all_float_types_ir10(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
Expand All @@ -375,16 +341,16 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* Multinomial_ver7_doc = R"DOC(
static const char* Multinomial_ver22_doc = R"DOC(
Generate a tensor of samples from a multinomial distribution according to the probabilities
of each of the possible outcomes.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Multinomial,
7,
22,
OpSchema()
.SetDoc(Multinomial_ver7_doc)
.SetDoc(Multinomial_ver22_doc)
.Attr("sample_size", "Number of times to sample.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"seed",
Expand All @@ -406,10 +372,7 @@ ONNX_OPERATOR_SET_SCHEMA(
"output",
"Output tensor with shape [batch_size, sample_size], where sample_size is the number of times to sample. Each value along the axis zero represents the outcome of the corresponding sample in a batch.",
"T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain output types to integral tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto dtype = ctx.getAttribute("dtype");
Expand Down Expand Up @@ -562,7 +525,7 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

static const char* Bernoulli_ver15_doc = R"DOC(
static const char* Bernoulli_ver22_doc = R"DOC(
Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number,
where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p).
Expand All @@ -573,9 +536,9 @@ implementations (even if a seed is specified).

ONNX_OPERATOR_SET_SCHEMA(
Bernoulli,
15,
22,
OpSchema()
.SetDoc(Bernoulli_ver15_doc)
.SetDoc(Bernoulli_ver22_doc)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
Expand All @@ -589,25 +552,10 @@ ONNX_OPERATOR_SET_SCHEMA(
OPTIONAL_VALUE)
.Input(0, "input", "All values in input have to be in the range:[0, 1].", "T1")
.Output(0, "output", "The returned output tensor only has values 0 or 1, same shape as input tensor.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types to all numeric tensors and bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
Expand Down
Loading

0 comments on commit 1f081eb

Please sign in to comment.