# Sampler

torch.utils.data.Sampler

Base class for all Samplers.
Every Sampler subclass has to provide an `__iter__()` method, providing a way to iterate over indices of dataset elements, and a `__len__()` method that returns the length of the returned iterators.

In [1]:
import torch

In [2]:
# Toy dataset of 10 elements; every sampler below yields indices into it.
data_source = [2, 5, 7, 11, 22, 33, 1, 3, 4, 55]

## SequentialSampler

In [3]:
# --------------------------------------------------------------------------
# SequentialSampler: walks the indices of data_source in order, so listing it
# prints every index from 0 up to len(data_source) - 1.
# --------------------------------------------------------------------------
sampler1 = torch.utils.data.SequentialSampler(data_source=data_source)
print(list(sampler1))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [4]:
# Explicit iterator over sampler1 so indices can be pulled one at a time with
# next(); the drain loop further down consumes it.
sampler_iterater = iter(sampler1)

In [5]:
# Uncomment to pull indices by hand. Each next() advances the same iterator,
# so the drain loop in the following cell would then start at a later index.
#next(sampler_iterater)
#next(sampler_iterater)

In [6]:
# Drain the iterator over sampler1, printing each index and the element it
# selects. A `for` loop stops cleanly when the iterator is exhausted; unlike
# the previous bare `except: break`, any other error (e.g. an IndexError from
# data_source[idx]) is no longer silently swallowed.
for idx in sampler_iterater:
    print(idx)
    print(data_source[idx])
    print("----------")

0
2
----------
1
5
----------
2
7
----------
3
11
----------
4
22
----------
5
33
----------
6
1
----------
7
3
----------
8
4
----------
9
55
----------


## RandomSampler

In [7]:
# RandomSampler without replacement. num_samples=None is passed explicitly;
# the printed length equals len(data_source).
sampler2 = torch.utils.data.RandomSampler(
    data_source=data_source,
    replacement=False,
    num_samples=None,
)
print(len(sampler2))

10


In [8]:
# Rebind the shared name to a fresh iterator over sampler2 (a RandomSampler);
# the drain loop that follows consumes it.
sampler_iterater = iter(sampler2)

In [9]:
# Drain the iterator over sampler2, printing each randomly ordered index and
# the element it selects. The `for` loop ends on exhaustion only; genuine
# errors propagate instead of being hidden by a bare `except:`.
for idx in sampler_iterater:
    print(idx)
    print(data_source[idx])
    print("----------")

0
2
----------
5
33
----------
4
22
----------
9
55
----------
1
5
----------
6
1
----------
8
4
----------
2
7
----------
3
11
----------
7
3
----------


In [10]:
# Same setup as sampler2 but leaving num_samples at its default; the printed
# length (10) again equals len(data_source).
sampler3 = torch.utils.data.RandomSampler(data_source=data_source, replacement=False)
# Debug aid: previously pointed at sampler2 by copy-paste mistake.
#print(list(sampler3))
print(len(sampler3))

10


In [13]:
# Rebind the shared name to a fresh iterator over sampler3; the drain loop
# that follows consumes it.
sampler_iterater = iter(sampler3)

In [14]:
# Drain the iterator over sampler3, printing each index and the element it
# selects. The `for` loop terminates on exhaustion; unexpected errors are no
# longer masked by a bare `except:`.
for idx in sampler_iterater:
    print(idx)
    print(data_source[idx])
    print("----------")

8
4
----------
4
22
----------
9
55
----------
5
33
----------
7
3
----------
2
7
----------
1
5
----------
3
11
----------
6
1
----------
0
2
----------


In [11]:
# RandomSampler with replacement: exactly num_samples=5 indices are drawn, and
# because draws are independent the same index may appear more than once.
sampler4 = torch.utils.data.RandomSampler(
    data_source=data_source,
    replacement=True,
    num_samples=5,
)
print(len(sampler4))

5


In [15]:
# Rebind the shared name to a fresh iterator over sampler4 (with replacement);
# the drain loop that follows consumes it.
sampler_iterater = iter(sampler4)

In [16]:
# Drain the iterator over sampler4, printing each drawn index (repeats are
# possible) and the element it selects. The `for` loop stops on exhaustion;
# other errors propagate rather than being swallowed by a bare `except:`.
for idx in sampler_iterater:
    print(idx)
    print(data_source[idx])
    print("----------")

3
11
----------
6
1
----------
2
7
----------
3
11
----------
5
33
----------


## BatchSampler

In [17]:
# --------------------------------------------------------------------------
# BatchSampler wraps another sampler (here the RandomSampler sampler2) and
# groups the indices it yields into lists of batch_size. With
# drop_last=False the final, shorter batch is kept.
# --------------------------------------------------------------------------
batch_sampler = torch.utils.data.BatchSampler(
    sampler2,
    batch_size=3,
    drop_last=False,
)
print(list(batch_sampler))

[[9, 6, 4], [2, 5, 8], [3, 0, 1], [7]]


In [18]:
# Same wrapping of sampler2, but drop_last=True discards a trailing batch
# that has fewer than batch_size indices, so only full batches remain.
batch_sampler = torch.utils.data.BatchSampler(
    sampler2,
    batch_size=3,
    drop_last=True,
)
print(list(batch_sampler))

[[7, 2, 0], [4, 1, 8], [5, 9, 3]]
