Skip to content

Commit

Permalink
Update iterate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Jul 3, 2019
1 parent 11ef84d commit 903186f
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions tensorlayer/iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def minibatches(inputs=None, targets=None, batch_size=None, allow_dynamic_batch_
>>> y = np.asarray([0,1,2,3,4,5])
>>> for batch in tl.iterate.minibatches(inputs=X, targets=y, batch_size=2, shuffle=False):
>>> print(batch)
(array([['a', 'a'], ['b', 'b']], dtype='<U1'), array([0, 1]))
(array([['c', 'c'], ['d', 'd']], dtype='<U1'), array([2, 3]))
(array([['e', 'e'], ['f', 'f']], dtype='<U1'), array([4, 5]))
... (array([['a', 'a'], ['b', 'b']], dtype='<U1'), array([0, 1]))
... (array([['c', 'c'], ['d', 'd']], dtype='<U1'), array([2, 3]))
... (array([['e', 'e'], ['f', 'f']], dtype='<U1'), array([4, 5]))
Notes
-----
Expand Down Expand Up @@ -97,8 +97,8 @@ def seq_minibatches(inputs, targets, batch_size, seq_length, stride=1):
>>> y = np.asarray([0, 1, 2, 3, 4, 5])
>>> for batch in tl.iterate.seq_minibatches(inputs=X, targets=y, batch_size=2, seq_length=2, stride=1):
>>> print(batch)
(array([['a', 'a'], ['b', 'b'], ['b', 'b'], ['c', 'c']], dtype='<U1'), array([0, 1, 1, 2]))
(array([['c', 'c'], ['d', 'd'], ['d', 'd'], ['e', 'e']], dtype='<U1'), array([2, 3, 3, 4]))
... (array([['a', 'a'], ['b', 'b'], ['b', 'b'], ['c', 'c']], dtype='<U1'), array([0, 1, 1, 2]))
... (array([['c', 'c'], ['d', 'd'], ['d', 'd'], ['e', 'e']], dtype='<U1'), array([2, 3, 3, 4]))
Many to One
Expand All @@ -112,14 +112,14 @@ def seq_minibatches(inputs, targets, batch_size, seq_length, stride=1):
>>> tmp_y = y.reshape((-1, num_steps) + y.shape[1:])
>>> y = tmp_y[:, -1]
>>> print(x, y)
[['a' 'a']
['b' 'b']
['b' 'b']
['c' 'c']] [1 2]
[['c' 'c']
['d' 'd']
['d' 'd']
['e' 'e']] [3 4]
... [['a' 'a']
... ['b' 'b']
... ['b' 'b']
... ['c' 'c']] [1 2]
... [['c' 'c']
... ['d' 'd']
... ['d' 'd']
... ['e' 'e']] [3 4]
"""
if len(inputs) != len(targets):
Expand Down Expand Up @@ -171,21 +171,21 @@ def seq_minibatches2(inputs, targets, batch_size, num_steps):
>>> for batch in tl.iterate.seq_minibatches2(X, Y, batch_size=2, num_steps=3):
... x, y = batch
... print(x, y)
[[ 0. 1. 2.]
[ 10. 11. 12.]]
[[ 20. 21. 22.]
[ 30. 31. 32.]]
[[ 3. 4. 5.]
[ 13. 14. 15.]]
[[ 23. 24. 25.]
[ 33. 34. 35.]]
[[ 6. 7. 8.]
[ 16. 17. 18.]]
[[ 26. 27. 28.]
[ 36. 37. 38.]]
...
... [[ 0. 1. 2.]
... [ 10. 11. 12.]]
... [[ 20. 21. 22.]
... [ 30. 31. 32.]]
...
... [[ 3. 4. 5.]
... [ 13. 14. 15.]]
... [[ 23. 24. 25.]
... [ 33. 34. 35.]]
...
... [[ 6. 7. 8.]
... [ 16. 17. 18.]]
... [[ 26. 27. 28.]
... [ 36. 37. 38.]]
Notes
-----
Expand Down Expand Up @@ -249,20 +249,20 @@ def ptb_iterator(raw_data, batch_size, num_steps):
>>> for batch in tl.iterate.ptb_iterator(train_data, batch_size=2, num_steps=3):
>>> x, y = batch
>>> print(x, y)
[[ 0 1 2] <---x 1st subset/ iteration
[10 11 12]]
[[ 1 2 3] <---y
[11 12 13]]
[[ 3 4 5] <--- 1st batch input 2nd subset/ iteration
[13 14 15]] <--- 2nd batch input
[[ 4 5 6] <--- 1st batch target
[14 15 16]] <--- 2nd batch target
[[ 6 7 8] 3rd subset/ iteration
[16 17 18]]
[[ 7 8 9]
[17 18 19]]
... [[ 0 1 2] <---x 1st subset/ iteration
... [10 11 12]]
... [[ 1 2 3] <---y
... [11 12 13]]
...
... [[ 3 4 5] <--- 1st batch input 2nd subset/ iteration
... [13 14 15]] <--- 2nd batch input
... [[ 4 5 6] <--- 1st batch target
... [14 15 16]] <--- 2nd batch target
...
... [[ 6 7 8] 3rd subset/ iteration
... [16 17 18]]
... [[ 7 8 9]
... [17 18 19]]
"""
raw_data = np.array(raw_data, dtype=np.int32)

Expand Down

0 comments on commit 903186f

Please sign in to comment.