Skip to content

Commit

Permalink
Add index check for Transpose's type inference function (onnx#1053)
Browse files Browse the repository at this point in the history
* Add index check for Transpose's type and shape inference function

* Add index check

* Update defs.cc

Change the call from fail_type_inference to fail_shape_inference
  • Loading branch information
snnn authored and Raymond Yang committed Jun 8, 2018
1 parent cce8f48 commit 4821f0d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
16 changes: 11 additions & 5 deletions onnx/defs/nn/defs.cc
Expand Up @@ -73,7 +73,7 @@ void convPoolTypeAndShapeInference(
std::vector<int64_t> pads;
if (getRepeatedAttribute(ctx, "pads", pads)) {
if (pads.size() != n_input_dims * 2) {
fail_shape_inference("Attribute pads has incorrect size");;
fail_shape_inference("Attribute pads has incorrect size");
}
} else {
pads.assign(n_input_dims * 2, 0);
Expand All @@ -82,7 +82,7 @@ void convPoolTypeAndShapeInference(
std::vector<int64_t> strides;
if (getRepeatedAttribute(ctx, "strides", strides)) {
if (strides.size() != n_input_dims) {
fail_shape_inference("Attribute strides has incorrect size");;
fail_shape_inference("Attribute strides has incorrect size");
}
} else {
strides.assign(n_input_dims, 1);
Expand All @@ -91,10 +91,10 @@ void convPoolTypeAndShapeInference(
std::vector<int64_t> kernel_shape;
if (getRepeatedAttribute(ctx, "kernel_shape", kernel_shape)) {
if (kernel_shape.size() != n_input_dims) {
fail_shape_inference("Attribute kernel_shape has incorrect size");;
fail_shape_inference("Attribute kernel_shape has incorrect size");
}
} else if (require_kernel_shape) {
fail_shape_inference("Attribute kernel_shape must be specified");;
fail_shape_inference("Attribute kernel_shape must be specified");
} else {
auto second_input_shape = ctx.getInputType(1)->tensor_type().shape();
for (int i = 2; i < second_input_shape.dim_size(); ++i) {
Expand Down Expand Up @@ -1107,8 +1107,14 @@ ONNX_OPERATOR_SET_SCHEMA(
int axis = static_cast<int>(getAttribute(ctx, "axis", 1));
if (axis > rank || axis < 0) {
fail_shape_inference(
"Invalid value(", axis, ") for attribute 'axis'");
"Invalid value(" , axis , ") for attribute 'axis'");
}
// TODO: is the operation defined for input-rank < 2?
updateOutputShape(
ctx,
0,
{multiplyDims(input_shape, 0, axis),
multiplyDims(input_shape, axis, rank)});
}));

static const char* LRN_ver1_doc = R"DOC(
Expand Down
4 changes: 4 additions & 0 deletions onnx/defs/shape_inference.h
Expand Up @@ -113,6 +113,9 @@ inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, i
return result;
}

// If from >= upto_exclusive, returns 1.
// Caller must ensure that upto_exclusive is less than or equal to shape.size().
// Caller must ensure that from >= 0.
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
TensorShapeProto::Dimension dim;
dim.set_dim_value(1);
Expand Down Expand Up @@ -170,6 +173,7 @@ inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) {
return ctx.getInputType(n)->tensor_type().shape();
}

// Caller must ensure that fromDimIndex is strictly less than shape.dim_size().
inline void appendSingleDimCopiedFromInputTypeToOutputType(
InferenceContext& ctx,
size_t inputIndex,
Expand Down
40 changes: 30 additions & 10 deletions onnx/defs/tensor/defs.cc
Expand Up @@ -295,8 +295,11 @@ ONNX_OPERATOR_SET_SCHEMA(
if (!ctx.getInputType(0)->tensor_type().has_shape()) {
return;
}
const auto& splitDim =
ctx.getInputType(0)->tensor_type().shape().dim(axis);
const auto& shape = ctx.getInputType(0)->tensor_type().shape();
if (axis >= shape.dim_size()) {
fail_type_inference("Invalid value of attribute 'axis'");
}
const auto& splitDim = shape.dim(axis);
if (!splitDim.has_dim_value()) {
return;
}
Expand All @@ -311,7 +314,7 @@ ONNX_OPERATOR_SET_SCHEMA(

for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
*ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() =
ctx.getInputType(0)->tensor_type().shape();
shape;
ctx.getOutputType(i)
->mutable_tensor_type()
->mutable_shape()
Expand Down Expand Up @@ -473,15 +476,32 @@ ONNX_OPERATOR_SET_SCHEMA(
if (!hasNInputShapes(ctx, 1)) {
return;
}

auto input_type = ctx.getInputType(0);
const TensorShapeProto& shape = input_type->tensor_type().shape();
std::vector<int64_t> perm;
if (!getRepeatedAttribute(ctx, "perm", perm)) {
for (int i =
ctx.getInputType(0)->tensor_type().shape().dim_size() - 1;
i >= 0;
--i) {
bool has_perm_attr = getRepeatedAttribute(ctx, "perm", perm);
if (!has_perm_attr) {
for (int i = shape.dim_size() - 1; i >= 0; --i)
perm.push_back(i);
}
} else if (!perm.empty()) {
// check if every index is valid
for (int64_t fromDimIndex : perm)
if (!(0 <= fromDimIndex && fromDimIndex < shape.dim_size())) {
std::ostringstream oss;
oss << "Invalid attribute perm {" << perm[0];
for (size_t i = 1; i != perm.size(); ++i) {
oss << ", " << perm[i];
}
oss << "}, input shape = {";
if (shape.dim_size() > 0) {
oss << shape.dim(0).dim_value();
for (int i = 1; i != shape.dim_size(); ++i) {
oss << ", " << shape.dim(i).dim_value();
}
oss << "}";
}
fail_type_inference(oss.str());
}
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);
Expand Down

0 comments on commit 4821f0d

Please sign in to comment.