Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent memory overflow in ParseAttrValue from nested tensors.
PiperOrigin-RevId: 370108442
Change-Id: I84d64a5e8895a6aeffbf4749841b4c54d51b5889
  • Loading branch information
pak-laura authored and tensorflower-gardener committed Apr 23, 2021
1 parent c22d88d commit e07e1c3
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tensorflow/core/framework/attr_value_util.cc
Expand Up @@ -38,6 +38,9 @@ namespace {
// Do not construct large tensors to compute their hash or compare for equality.
constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb

// Limit nesting of tensors to 100 deep to prevent memory overflow.
constexpr int kMaxTensorNestDepth = 100;

// Return the size of the tensor represented by this TensorProto. If shape is
// not fully defined return -1.
int64 TensorByteSize(const TensorProto& t) {
Expand Down Expand Up @@ -224,6 +227,54 @@ string SummarizeFunc(const NameAttrList& func) {
return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
}

bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, string to_parse) {
int nests = 0;
int maxed_out = to_parse.length();
int open_curly = to_parse.find('{');
int open_bracket = to_parse.find('<');
int close_curly = to_parse.find('}');
int close_bracket = to_parse.find('>');
if (open_curly == -1) {
open_curly = maxed_out;
}
if (open_bracket == -1) {
open_bracket = maxed_out;
}
int min = std::min(open_curly, open_bracket);
do {
if (open_curly == maxed_out && open_bracket == maxed_out) {
return true;
}
if (min == open_curly) {
nests += 1;
open_curly = to_parse.find('{', open_curly + 1);
if (open_curly == -1) {
open_curly = maxed_out;
}
} else if (min == open_bracket) {
nests += 1;
open_bracket = to_parse.find('<', open_bracket + 1);
if (open_bracket == -1) {
open_bracket = maxed_out;
}
} else if (min == close_curly) {
nests -= 1;
close_curly = to_parse.find('}', close_curly + 1);
if (close_curly == -1) {
close_curly = maxed_out;
}
} else if (min == close_bracket) {
nests -= 1;
close_bracket = to_parse.find('>', close_bracket + 1);
if (close_bracket == -1) {
close_bracket = maxed_out;
}
}
min = std::min({open_curly, open_bracket, close_curly, close_bracket});
} while (nests < 100);
return false;
}

} // namespace

string SummarizeAttrValue(const AttrValue& attr_value) {
Expand Down Expand Up @@ -448,7 +499,12 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
} else {
to_parse = strings::StrCat(field_name, ": ", text);
}

if (field_name == "tensor") {
if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
to_parse)) {
return false;
}
}
return ProtoParseFromString(to_parse, out);
}

Expand Down

0 comments on commit e07e1c3

Please sign in to comment.