Skip to content

Commit da85585

Browse files
Fix undefined behavior in tf.raw_ops.Switch in eager mode.
PiperOrigin-RevId: 332578058 Change-Id: I9727571d2f21476b10d8aa27c1b7176564b76ac9
1 parent 22e07fb commit da85585

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

Diff for: tensorflow/core/common_runtime/eager/kernel_and_device.cc

+6-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,12 @@ Status KernelAndDeviceOp::Run(
308308
if (outputs != nullptr) {
309309
outputs->clear();
310310
for (int i = 0; i < context.num_outputs(); ++i) {
311-
outputs->push_back(Tensor(*context.mutable_output(i)));
311+
const auto* output_tensor = context.mutable_output(i);
312+
if (output_tensor != nullptr) {
313+
outputs->push_back(Tensor(*output_tensor));
314+
} else {
315+
outputs->push_back(Tensor());
316+
}
312317
}
313318
}
314319
return Status::OK();

Diff for: tensorflow/python/kernel_tests/control_flow_ops_py_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -4579,6 +4579,14 @@ def testUInt64SwitchMerge(self):
45794579
result = control_flow_ops.merge([v_f, v_t])
45804580
self.evaluate(result)
45814581

4582+
def testSwitchEagerMode(self):
4583+
if not context.executing_eagerly():
4584+
return
4585+
input_data = [1, 2, 3, 4]
4586+
vf, vt = control_flow_ops.switch(input_data, False)
4587+
self.assertAllEqual(vf, input_data)
4588+
self.assertAllEqual(vt, [])
4589+
45824590
@test_util.run_deprecated_v1
45834591
def testQIntArgAndRet(self):
45844592

0 commit comments

Comments
 (0)