Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent CHECK-fail when decoding resource handles from proto
In certain scenarios, the proto might contain tensors that have too many elements (overflow). This is a `CHECK`-fail in general, but we should prevent this, given how many CVEs caused by that we have received this year (a large fraction of 200).

PiperOrigin-RevId: 408049766
Change-Id: I2ac20b247aa8ed9110846fbdb7a0a9401f2c168c
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Nov 6, 2021
1 parent ce8ac88 commit 14fea66
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 8 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/framework/BUILD
Expand Up @@ -734,7 +734,9 @@ cc_library(
"//tensorflow/core/lib/core:errors",
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:intrusive_ptr",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:statusor",
"//tensorflow/core/platform:tensor_coding",
"//tensorflow/core/platform:types",
Expand Down
31 changes: 24 additions & 7 deletions tensorflow/core/framework/resource_handle.cc
Expand Up @@ -17,8 +17,11 @@ limitations under the License.

#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

Expand All @@ -28,7 +31,15 @@ namespace tensorflow {
ResourceHandle::ResourceHandle() {}

ResourceHandle::ResourceHandle(const ResourceHandleProto& proto) {
FromProto(proto);
TF_CHECK_OK(FromProto(proto));
}

Status ResourceHandle::BuildResourceHandle(const ResourceHandleProto& proto,
ResourceHandle* out) {
if (out == nullptr)
return errors::Internal(
"BuildResourceHandle() was called with nullptr for the output");
return out->FromProto(proto);
}

ResourceHandle::~ResourceHandle() {}
Expand All @@ -46,7 +57,7 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const {
}
}

void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
Status ResourceHandle::FromProto(const ResourceHandleProto& proto) {
set_device(proto.device());
set_container(proto.container());
set_name(proto.name());
Expand All @@ -55,10 +66,16 @@ void ResourceHandle::FromProto(const ResourceHandleProto& proto) {
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
for (const auto& dtype_and_shape : proto.dtypes_and_shapes()) {
DataType dtype = dtype_and_shape.dtype();
PartialTensorShape shape(dtype_and_shape.shape());
PartialTensorShape shape;
Status s = PartialTensorShape::BuildPartialTensorShape(
dtype_and_shape.shape(), &shape);
if (!s.ok()) {
return s;
}
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape});
}
dtypes_and_shapes_ = std::move(dtypes_and_shapes);
return Status::OK();
}

string ResourceHandle::SerializeAsString() const {
Expand All @@ -69,9 +86,7 @@ string ResourceHandle::SerializeAsString() const {

bool ResourceHandle::ParseFromString(const string& s) {
ResourceHandleProto proto;
const bool status = proto.ParseFromString(s);
if (status) FromProto(proto);
return status;
return proto.ParseFromString(s) && FromProto(proto).ok();
}

string ResourceHandle::DebugString() const {
Expand Down Expand Up @@ -140,7 +155,9 @@ bool DecodeResourceHandleList(std::unique_ptr<port::StringListDecoder> d,
if (!proto.ParseFromArray(d->Data(sizes[i]), sizes[i])) {
return false;
}
ps[i].FromProto(proto);
if (!ps[i].FromProto(proto).ok()) {
return false;
}
}
return true;
}
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/core/framework/resource_handle.h
Expand Up @@ -46,6 +46,11 @@ class ResourceHandle {
ResourceHandle(const ResourceHandleProto& proto);
~ResourceHandle();

// Use this factory method if the `proto` comes from user controlled input, to
// prevent a denial of service.
static Status BuildResourceHandle(const ResourceHandleProto& proto,
ResourceHandle* out);

// Unique name for the device containing the resource.
const std::string& device() const { return device_; }

Expand Down Expand Up @@ -91,7 +96,7 @@ class ResourceHandle {

// Conversion to and from ResourceHandleProto
void AsProto(ResourceHandleProto* proto) const;
void FromProto(const ResourceHandleProto& proto);
Status FromProto(const ResourceHandleProto& proto);

// Serialization via ResourceHandleProto
std::string SerializeAsString() const;
Expand Down
40 changes: 40 additions & 0 deletions tensorflow/core/framework/tensor.cc
Expand Up @@ -537,6 +537,46 @@ TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
return buf;
}

// Separate implementation for `ResourceHandle` to handle the case when the
// proto for the resource is invalid. See `resource_handle.h` constructor and
// static factory builder.
template <>
TensorBuffer* FromProtoField<ResourceHandle>(Allocator* a,
const TensorProto& in, int64_t n) {
CHECK_GT(n, 0);
Buffer<ResourceHandle>* buf = new Buffer<ResourceHandle>(a, n);
ResourceHandle* data = buf->template base<ResourceHandle>();
if (data == nullptr) {
buf->Unref();
return nullptr;
}
const int64_t in_n = ProtoHelper<ResourceHandle>::NumElements(in);
if (in_n <= 0) {
std::fill_n(data, n, ResourceHandle());
} else {
// If tensor shape says we have n < in_n elements in the output tensor
// then make sure to only decode the first n out of the in_n elements in the
// in tensors. In all other cases, we decode all in_n elements of in and set
// the remaining elements up to n to be the default ResourceHandle() value.
const int64_t real_n = n < in_n ? n : in_n;
for (int64_t i = 0; i < real_n; ++i) {
Status s = ResourceHandle::BuildResourceHandle(in.resource_handle_val(i),
&data[i]);
if (!s.ok()) {
LOG(ERROR) << "Could not decode resource handle from proto \""
<< in.resource_handle_val(i).ShortDebugString()
<< "\", returned status: " << s.ToString();
buf->Unref();
return nullptr;
}
}
for (int64_t i = in_n; i < n; ++i) {
data[i] = ResourceHandle();
}
}
return buf;
}

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
int64_t n) {
Expand Down

0 comments on commit 14fea66

Please sign in to comment.