Skip to content

Commit

Permalink
Merge pull request #52257 from wamuir:go-validate-newtensor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 401774163
Change-Id: I17735dd5241fa9115e46c54335826979072faa1f
  • Loading branch information
tensorflower-gardener committed Oct 8, 2021
2 parents aff4f3c + aa700a8 commit c36d5a1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
23 changes: 12 additions & 11 deletions tensorflow/go/tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ func NewTensor(value interface{}) (*Tensor, error) {

raw := tensorData(t.c)

defer runtime.SetFinalizer(t, func(t *Tensor) {
t.finalize()
})
runtime.SetFinalizer(t, (*Tensor).finalize)

buf := bytes.NewBuffer(raw[:0:len(raw)])

Expand All @@ -115,10 +113,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
// not be contiguous with the others or in the order we might
// expect, so we need to work our way down to each slice of
// primitives and copy them individually
if n, err := encodeTensorWithSlices(buf, val, shape); err != nil {
// Set nbytes to count of bytes written for deferred call to
// runtime.SetFinalizer
nbytes = uintptr(n)
if _, err := encodeTensorWithSlices(buf, val, shape); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -422,11 +417,17 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
typ := val.Type()
for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice {
shape = append(shape, int64(val.Len()))
// If slice elements are slices, verify that all of them have the same size.
// Go's type system makes that guarantee for arrays.
if val.Len() > 0 {
// In order to check tensor structure properly in general case we need to iterate over all slices of the tensor to check sizes match
// Since we already going to iterate over all elements in encodeTensor() let's
// 1) do the actual check in encodeTensor() to save some cpu cycles here
// 2) assume the shape is represented by lengths of elements with zero index in each dimension
if val.Type().Elem().Kind() == reflect.Slice {
expected := val.Index(0).Len()
for i := 1; i < val.Len(); i++ {
if val.Index(i).Len() != expected {
return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected)
}
}
}
val = val.Index(0)
}
typ = typ.Elem()
Expand Down
46 changes: 38 additions & 8 deletions tensorflow/go/tensor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"reflect"
"runtime"
"testing"
)

Expand Down Expand Up @@ -77,14 +78,6 @@ func TestNewTensor(t *testing.T) {
// native ints not supported
int(5),
[]int{5},
// Mismatched dimensions
[][]float32{{1, 2, 3}, {4}},
// Mismatched dimensions. Should return "mismatched slice lengths" error instead of "BUG"
[][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}},
// Mismatched dimensions. Should return error instead of valid tensor
[][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}, {{1, 2, 3}, {2, 3, 4}}},
// Mismatched dimensions for strings
[][]string{{"abc"}, {"abcd", "abcd"}},
}

for _, test := range tests {
Expand Down Expand Up @@ -117,6 +110,43 @@ func TestNewTensor(t *testing.T) {
}
}

// TestNewTensorValidateDimensions checks that NewTensor rejects ragged
// (non-rectangular) nested slices: every input below must yield a non-nil
// error and a nil *Tensor.
func TestNewTensorValidateDimensions(t *testing.T) {
	var errorTests = []interface{}{
		// Mismatched dimensions
		[][]float32{{1, 2, 3}, {4}},
		// Mismatched dimensions. Should return "mismatched slice lengths" error instead of "BUG"
		[][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}},
		// Mismatched dimensions. Should return error instead of valid tensor
		[][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}, {{1, 2, 3}, {2, 3, 4}}},
		// Mismatched dimensions for strings
		[][]string{{"abc"}, {"abcd", "abcd"}},
	}

	// Test that an error is returned in response to mismatched dimensions
	// and that no tensor is returned. Dimensions should be checked and a
	// mismatch caught in NewTensor prior to actually allocating a new
	// tensor in cgo. Given how string tensors are encoded and how tensors
	// are freed, a mismatch caught partway through encoding a string
	// tensor may result in a segfault, once the finalizer is called. A
	// single run of this test is not reliable at producing a segfault,
	// hence iteration. See github.com/tensorflow/tensorflow/pull/52257
	// for some detail on the issue.
	const iterations = 1e5
	for i := 0; i < iterations; i++ {
		for _, test := range errorTests {
			tensor, err := NewTensor(test)
			if err == nil {
				// Report the missing error explicitly; formatting the nil
				// error itself would give an uninformative failure message.
				t.Errorf("NewTensor(%v) returned a nil error, want non-nil error", test)
			}
			if tensor != nil {
				t.Errorf("NewTensor(%v) = %v, want nil", test, tensor)
			}
		}
	}

	// Force a collection (runtime.GC blocks until the collection is done)
	// so that finalizers of any tensors that became unreachable above are
	// scheduled while the test binary is still running. Finalizers run on
	// a separate goroutine, so this is best-effort rather than a guarantee
	// that every finalizer has completed by the time the test returns.
	runtime.GC()
}

func TestTensorSerialization(t *testing.T) {
var tests = []interface{}{
bool(true),
Expand Down

0 comments on commit c36d5a1

Please sign in to comment.