Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix pywrap attribute read security vulnerability.
If a list of quantized tensors is assigned to an attribute, the pywrap code was failing to
parse the tensor and returning a `nullptr`, which wasn't caught.  Here we check the return
value and set an appropriate error status.

PiperOrigin-RevId: 476981029
  • Loading branch information
cantonios authored and tensorflower-gardener committed Sep 26, 2022
1 parent e7ed22e commit e9e9555
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tensorflow/python/eager/pywrap_tfe_src.cc
Expand Up @@ -397,11 +397,20 @@ bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
const int num_values = PySequence_Size(py_list);
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

#define PARSE_LIST(c_type, parse_fn) \
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
for (int i = 0; i < num_values; ++i) { \
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
#define PARSE_LIST(c_type, parse_fn) \
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
for (int i = 0; i < num_values; ++i) { \
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
if (py_value == nullptr) { \
TF_SetStatus(status, TF_INVALID_ARGUMENT, \
tensorflow::strings::StrCat( \
"Expecting sequence of " #c_type " for attr ", key, \
", got ", py_list->ob_type->tp_name) \
.c_str()); \
return false; \
} else if (!parse_fn(key, py_value.get(), status, &values[i])) { \
return false; \
} \
}

if (type == TF_ATTR_STRING) {
Expand Down
Expand Up @@ -17,7 +17,9 @@
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


Expand Down Expand Up @@ -139,6 +141,17 @@ def testComplexDataTypes(self):
padding=padding,
patches=patches)

def testInvalidAttributes(self):
"""Test for passing weird things into ksizes."""
with self.assertRaisesRegex(TypeError, "Expected list"):
image = constant_op.constant([0.0])
ksizes = math_ops.cast(
constant_op.constant(dtype=dtypes.int16, value=[[1, 4], [5, 2]]),
dtype=dtypes.qint16)
strides = [1, 1, 1, 1]
self.evaluate(
array_ops.extract_image_patches(
image, ksizes=ksizes, strides=strides, padding="SAME"))

if __name__ == "__main__":
test.main()

0 comments on commit e9e9555

Please sign in to comment.