Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix multiple vulnerabilities in tf.experimental.dlpack.to_dlpack.
We have a use after free caused by memory coruption, a segmentation fault caused by memory corruption, several memory leaks and an undefined behavior when taking the reference of a nullptr.

PiperOrigin-RevId: 332568894
Change-Id: Ife0fc05e103b35325094ae5d822ee5fdea764572
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 19, 2020
1 parent 390611e commit 22e07fb
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
28 changes: 22 additions & 6 deletions tensorflow/c/eager/dlpack.cc
Expand Up @@ -249,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
auto tf_dlm_context = GetDlContext(h, status);
if (!status->status.ok()) {
return nullptr;
}

auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
if (!status->status.ok()) {
return nullptr;
}

const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()

auto tf_dlm_type = GetDlDataType(data_type, status);
if (!status->status.ok()) {
return nullptr;
}

TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;

DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
dlm_tensor->dl_tensor.data = tf_dlm_data;
dlm_tensor->dl_tensor.dtype = tf_dlm_type;

std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
Expand All @@ -276,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}

dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
dlm_tensor->dl_tensor.shape = shape_arr->data();
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data.
// Here we choose option 2, since some frameworks didn't handle the strides
// argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.strides = stride_arr->data();

dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
Expand Down
1 change: 0 additions & 1 deletion tensorflow/python/dlpack/BUILD
Expand Up @@ -19,7 +19,6 @@ cuda_py_test(
name = "dlpack_test",
srcs = ["dlpack_test.py"],
srcs_version = "PY2AND3",
tags = ["noasan"], # TODO(b/159774807)
deps = [
":dlpack",
"//tensorflow/python/eager:test",
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/python/dlpack/dlpack_test.py
Expand Up @@ -20,9 +20,11 @@
from absl.testing import parameterized
import numpy as np


from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.ops import array_ops
Expand Down Expand Up @@ -105,6 +107,12 @@ def UnsupportedComplex64():
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
UnsupportedComplex64)

def testMustPassTensorArgumentToDLPack(self):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"The argument to `to_dlpack` must be a TF tensor, not Python object"):
dlpack.to_dlpack([1])


if __name__ == "__main__":
ops.enable_eager_execution()
Expand Down
9 changes: 8 additions & 1 deletion tensorflow/python/tfe_wrapper.cc
Expand Up @@ -1358,9 +1358,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// DLPack functions
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
PyObject* eager_tensor_pyobject_ptr = o.ptr();
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());

if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
status->status = tensorflow::errors::InvalidArgument(
"The argument to `to_dlpack` must be a TF tensor, not Python object");
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}

TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

Expand Down

0 comments on commit 22e07fb

Please sign in to comment.