
# The correct way to iterate over dataset samples in batches

Given a dataset with $\text{num_samples}$ rows that we want to process in batches of size $\text{batch_size}$, there are 3 cases to handle:

1.   $\text{batch_size} < \text{num_samples}$
2.   $\text{batch_size} = \text{num_samples}$
3.   $\text{batch_size} > \text{num_samples}$

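For cases 2 and 3, we want to make sure that we loop exactly once. For case 1, we want one iteration per full batch, plus one more iteration if any samples remain.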
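The target behaviour can be written down as a ceiling division. The sketch below is only a reference point for the experiments that follow: `expected_batches` is a hypothetical helper name, not something defined elsewhere in this notebook, and it uses float division, which is fine for sample counts of this size.

```python
import math

def expected_batches(num_samples, batch_size):
    """Reference count: full batches plus one more for any leftover samples."""
    return math.ceil(num_samples / batch_size)

batch_size = 20
for num_samples in (19, 20, 21):
    # 19 and 20 samples should need one batch, 21 samples should need two.
    print(num_samples, expected_batches(num_samples, batch_size))
```

Any loop bound we pick should agree with these counts for all three cases.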


In [1]:
print('num_samples < batch_size')
num_samples = 19
batch_size = 20
for i in range(num_samples // batch_size):
    start = i * batch_size
    end = start + batch_size
    print(start, end)


print('\nnum_samples == batch_size')
num_samples = 20
batch_size = 20
for i in range(num_samples // batch_size):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

print('\nnum_samples > batch_size')
num_samples = 21
batch_size = 20
for i in range(num_samples // batch_size):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

num_samples < batch_size

num_samples == batch_size
0 20

num_samples > batch_size
0 20


### We see that 'for i in range(num_samples // batch_size)' runs zero iterations when num_samples < batch_size, and misses 1 iteration (which should cover the remaining samples) when num_samples > batch_size.


In [2]:
print('num_samples < batch_size')
num_samples = 19
batch_size = 20
for i in range(num_samples // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

print('\nnum_samples == batch_size')
num_samples = 20
batch_size = 20
for i in range(num_samples // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

print('\nnum_samples > batch_size')
num_samples = 21
batch_size = 20
for i in range(num_samples // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

num_samples < batch_size
0 20

num_samples == batch_size
0 20
20 40

num_samples > batch_size
0 20
20 40


### Changing it to 'for i in range(num_samples // batch_size + 1)' leads to an extra iteration when num_samples == batch_size (and, more generally, whenever num_samples is an exact multiple of batch_size).

### Let's change it to 'for i in range((num_samples - 1) // batch_size + 1)' and see the results.
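For positive integers, `(num_samples - 1) // batch_size + 1` is the usual integer form of ceiling division. As a rough sanity check beyond the three hand-picked values, the sketch below compares each loop bound tried in this notebook against `math.ceil` over a grid of sizes. The names `candidates` and `count_mismatches` and the grid limits are illustrative assumptions.

```python
import math

# The three loop-count formulas tried in this notebook (n = num_samples, b = batch_size).
candidates = {
    "n // b": lambda n, b: n // b,
    "n // b + 1": lambda n, b: n // b + 1,
    "(n - 1) // b + 1": lambda n, b: (n - 1) // b + 1,
}

def count_mismatches(formula, max_samples=200, max_batch=50):
    """Count (n, b) pairs where the formula disagrees with ceil(n / b)."""
    return sum(
        formula(n, b) != math.ceil(n / b)
        for n in range(1, max_samples + 1)
        for b in range(1, max_batch + 1)
    )

for name, formula in candidates.items():
    print(name, count_mismatches(formula))
```

Only the last formula should report zero mismatches: the first is off whenever `n` is not a multiple of `b`, and the second is off whenever it is.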

In [3]:
print('num_samples < batch_size')
num_samples = 19
batch_size = 20
for i in range((num_samples - 1) // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

print('\nnum_samples == batch_size')
num_samples = 20
batch_size = 20
for i in range((num_samples - 1) // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

print('\nnum_samples > batch_size')
num_samples = 21
batch_size = 20
for i in range((num_samples - 1) // batch_size + 1):
    start = i * batch_size
    end = start + batch_size
    print(start, end)

num_samples < batch_size
0 20

num_samples == batch_size
0 20

num_samples > batch_size
0 20
20 40


## ✅ We see that 'for i in range((num_samples - 1) // batch_size + 1)' handles all the 3 cases correctly!
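One detail the printed ranges leave open is that `end` can run past `num_samples` (for example `20 40` with 21 samples). Python slicing tolerates that, but other indexing schemes may not. A minimal sketch of a reusable generator that uses the same loop bound and clamps the last batch is shown below; `batch_bounds` is a hypothetical helper name.

```python
def batch_bounds(num_samples, batch_size):
    """Yield (start, end) index pairs that cover range(num_samples) in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    num_batches = (num_samples - 1) // batch_size + 1
    for i in range(num_batches):
        start = i * batch_size
        # Clamp so the last batch stops at the final sample.
        end = min(start + batch_size, num_samples)
        yield start, end

for num_samples in (19, 20, 21):
    print(num_samples, list(batch_bounds(num_samples, 20)))
# 19 [(0, 19)]
# 20 [(0, 20)]
# 21 [(0, 20), (20, 21)]
```

Because Python's `//` rounds toward negative infinity, an empty dataset (`num_samples = 0`) gives `(0 - 1) // batch_size + 1 == 0`, so the generator yields nothing. In languages where integer division truncates toward zero, the same expression would give 1 and need a separate guard.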