Skip to content

Commit

Permalink
Replace CHECKs in v1 checkpoint loading codepath with returning errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 167392822
  • Loading branch information
tensorflower-gardener committed Sep 2, 2017
1 parent 7d5cbd7 commit ddba1e0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
9 changes: 6 additions & 3 deletions tensorflow/core/kernels/save_restore_tensor.cc
Expand Up @@ -216,9 +216,12 @@ void RestoreTensor(OpKernelContext* context,

if (output_shape.num_elements() == 0) return;

#define READER_COPY(T) \
case DataTypeToEnum<T>::value: \
reader->CopySliceData(tensor_name, slice_to_load, t->flat<T>().data()); \
#define READER_COPY(T) \
case DataTypeToEnum<T>::value: \
OP_REQUIRES(context, \
reader->CopySliceData(tensor_name, slice_to_load, \
t->flat<T>().data()), \
errors::InvalidArgument("Error copying slice data")); \
break;

switch (type) {
Expand Down
17 changes: 11 additions & 6 deletions tensorflow/core/util/tensor_slice_reader.h
Expand Up @@ -165,13 +165,18 @@ bool TensorSliceReader::CopySliceData(const string& name,
CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
// We read a record in the corresponding sstable
const string key = EncodeTensorNameSlice(name, slice_s);
CHECK(sss_[idx]->Get(key, &value))
<< "Failed to seek to the record for tensor " << name << ", slice "
<< slice_s.DebugString() << ": computed key = " << key;
if (!sss_[idx]->Get(key, &value)) {
VLOG(1) << "Failed to seek to the record for tensor " << name
<< ", slice " << slice_s.DebugString()
<< ": computed key = " << key;
return false;
}
SavedTensorSlices sts;
CHECK(ParseProtoUnlimited(&sts, value))
<< "Failed to parse the record for tensor " << name << ", slice "
<< slice_s.DebugString() << ": computed key = " << key;
if (!ParseProtoUnlimited(&sts, value)) {
VLOG(1) << "Failed to parse the record for tensor " << name << ", slice "
<< slice_s.DebugString() << ": computed key = " << key;
return false;
}
CopyDataFromTensorSliceToTensorSlice(
tss->shape(), slice_s, slice,
checkpoint::TensorProtoData<T>(sts.data().data()), data);
Expand Down

0 comments on commit ddba1e0

Please sign in to comment.