Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix TensorKey hash function.
The original hash function only used total estimated `AllocatedBytes()`, which (a) is an estimate per tensor, and (b) is a very poor hash function for constants (e.g. `int32_t`).
It also tried to access individual tensor bytes through `tensor.data()` of size `AllocatedBytes()`.  This led to ASAN failures because the `AllocatedBytes()` is an estimate of total bytes allocated by a tensor, including any pointed-to constructs (e.g. strings), and does not refer to contiguous bytes in the `.data()` buffer.  We couldn't use this byte vector anyways, since types like `tstring` include pointers, whereas we need to hash the string values themselves.

Modified the hash function to more closely mirror the `==` operator.  This correctly handles `tstring` and any numeric types that do have contiguous storage.  Other types are currently left as unimplemented.

PiperOrigin-RevId: 446265413
  • Loading branch information
cantonios authored and tensorflower-gardener committed May 3, 2022
1 parent 6967b1f commit 1b85a28
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
26 changes: 17 additions & 9 deletions tensorflow/core/framework/tensor_key.h
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

Expand All @@ -32,8 +33,7 @@ class TensorKey : public Tensor {
}
if (DataTypeCanUseMemcpy(t1.dtype())) {
return t1.tensor_data() == t2.tensor_data();
}
if (t1.dtype() == DT_STRING) {
} else if (t1.dtype() == DT_STRING) {
const auto s1 = t1.unaligned_flat<tstring>();
const auto s2 = t2.unaligned_flat<tstring>();
for (int64_t i = 0, n = t1.NumElements(); i < n; ++i) {
Expand All @@ -42,6 +42,9 @@ class TensorKey : public Tensor {
}
}
return true;
} else {
DCHECK(false) << "Unimplemented dtype " << DataTypeString(t1.dtype())
<< std::endl;
}
return false;
}
Expand All @@ -53,14 +56,19 @@ class TensorKey : public Tensor {
// Needed for absl hash function.
template <typename H>
friend H AbslHashValue(H h, const TensorKey& k) {
const uint8* d = static_cast<uint8*>(k.data());

This comment was marked as spam.

Copy link
@onixpunki

onixpunki Sep 1, 2022

g

size_t s = k.AllocatedBytes();

This comment was marked as spam.

Copy link
@onixpunki

onixpunki Sep 1, 2022

2m

std::vector<uint8> vec;
vec.reserve(s);
for (int i = 0; i < s; i++) {
vec.push_back(d[i]);
if (DataTypeCanUseMemcpy(k.dtype())) {
return H::combine(std::move(h), k.tensor_data());
} else if (k.dtype() == DT_STRING) {
const auto strs = k.unaligned_flat<tstring>();
for (int64_t i = 0, n = k.NumElements(); i < n; ++i) {
h = H::combine(std::move(h), strs(i));
}
return h;
} else {
DCHECK(false) << "Unimplemented dtype " << DataTypeString(k.dtype())
<< std::endl;
}
return H::combine(std::move(h), s);
return h;
}
};

Expand Down
2 changes: 0 additions & 2 deletions tensorflow/python/kernel_tests/data_structures/BUILD
Expand Up @@ -165,8 +165,6 @@ tf_py_test(
grpc_enabled = True,
tags = [
"no_windows", # TODO(b/192259628)
"noasan", # TODO(b/164696004)
"notsan", # TODO(b/164696004)
],
deps = [
"//tensorflow/python:array_ops",
Expand Down

2 comments on commit 1b85a28

@onixpunki

This comment was marked as spam.

@onixpunki

This comment was marked as spam.

Please sign in to comment.