Skip to content

Commit

Permalink
Allow the output of tf.argmax as index type
Browse files Browse the repository at this point in the history
This fix tries to fix the issue raised in tensorflow#8951 where
the following will raise a `TypeError`:
```
a = tf.constant([1, 2, 3], dtype=tf.float32)
b = tf.argmax(a)
tf.Session().run(a[b])

TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'.
```
The reason for the erorr is that, `strides` is added as `append(1)`
without type while `begin` is appended with type.

The mismatch of `strides` and `begin` causes the error.

This fix fixes the issue by cast the stride with the same type
as `begin` when needed.

This issue was raised in tensorflow#8951. It was also raised earlier in
tensorflow#206 (comment)

This fix fixes tensorflow#8951.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang committed May 2, 2017
1 parent 85827b2 commit 1159eb2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tensorflow/python/ops/array_ops.py
Expand Up @@ -470,7 +470,10 @@ def _SliceHelper(tensor, slice_spec, var=None):
else:
begin.append(s)
end.append(s + 1)
strides.append(1)
if isinstance(s, ops.Tensor):
strides.append(constant(1, s.dtype))
else:
strides.append(np.ones_like(s).dtype.type(1))
shrink_axis_mask |= (1 << index)
index += 1

Expand Down

0 comments on commit 1159eb2

Please sign in to comment.