Skip to content

Commit

Permalink
Merge pull request #9 from SrMouraSilva/issue-7-python3-batch
Browse files Browse the repository at this point in the history
Fix #7 Minibatches in python3
  • Loading branch information
yell committed Mar 6, 2019
2 parents 93ece34 + 764ca34 commit 8a2efc0
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions boltzmann_machines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def batch_iter(X, batch_size=10, verbose=False, desc='epoch'):
--------
>>> X = np.arange(36).reshape((12, 3))
>>> for X_b in batch_iter(X, batch_size=5):
... print X_b
... print(X_b)
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
Expand All @@ -34,15 +34,17 @@ def batch_iter(X, batch_size=10, verbose=False, desc='epoch'):
"""
X = np.asarray(X)
N = len(X)
n_batches = N / batch_size + (N % batch_size > 0)
n_batches = N // batch_size + (N % batch_size > 0)
gen = range(n_batches)
if verbose: gen = progress_bar(gen, leave=False, ncols=64, desc=desc)
if verbose:
gen = progress_bar(gen, leave=False, ncols=64, desc=desc)
for i in gen:
yield X[i*batch_size:(i + 1)*batch_size]

def epoch_iter(start_epoch, max_epoch, verbose=False):
gen = range(start_epoch + 1, max_epoch + 1)
if verbose: gen = progress_bar(gen, leave=True, ncols=84, desc='training')
if verbose:
gen = progress_bar(gen, leave=True, ncols=84, desc='training')
for epoch in gen:
yield epoch

Expand Down

0 comments on commit 8a2efc0

Please sign in to comment.