Skip to content

Commit

Permalink
fix keras-team#8087, K.arange accept tensor input
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Dec 16, 2017
1 parent 0d66dc4 commit 9c87163
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
9 changes: 7 additions & 2 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,8 +2055,13 @@ def arange(start, stop=None, step=1, dtype='int32'):
"""
# Match the behavior of numpy and Theano by returning an empty seqence.
if stop is None and start < 0:
start = 0
if stop is None:
try:
if start < 0:
start = 0
except TypeError:
start = tf.cond(start < 0, lambda: 0, lambda: start)

result = tf.range(start, limit=stop, delta=step, name='arange')
if dtype != 'int32':
result = cast(result, dtype)
Expand Down
9 changes: 9 additions & 0 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,15 @@ def test_arange(self):
t = backend.arange(10, dtype=dtype)
assert backend.dtype(t) == dtype

for backend in [KTH, KTF]:
start = k.constant(1)
t = backend.arange(start)
assert len(k.eval(t)) == 1

start = k.constant(-1)
t = backend.arange(start)
assert len(k.eval(t)) == 0

def test_in_train_phase(self):
for training in [True, False]:
check_two_tensor_operation('in_train_phase', (3, 3), (2, 2), [KTH, KTF],
Expand Down

0 comments on commit 9c87163

Please sign in to comment.