In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

In [2]:
np.__version__

'1.19.1'

In [3]:
class SampleDataset(Dataset):
    """Sample dataset."""

    def __init__(self):
        test_data = {}
        for i in range(100):
            test_data[str(i)] = i
        self.data = test_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data[str(idx)]
        return sample

In [4]:
TestSample = SampleDataset()
len(TestSample)

100

In [5]:
np.random.seed(1004)
SAMPLES_TO_PLOT = 50
sampled_indices = np.random.choice(len(TestSample), size = SAMPLES_TO_PLOT, replace = False)
print(sampled_indices)

[38  3  8 81 89 91  5 39 32 74 49 68 15 56 21 28 46  4 55 96 77 75 25 30
 73  1 40 79 93 24 36  6 35 99 83 54 82 45 44 33 65 78 13 80  7 26 59 42
 29 27]


In [6]:
a = [38,3, 8, 81, 89, 91,  5, 39, 32, 74, 49, 68, 15, 56, 21, 28, 46,  4, 55, 96, 77, 75, 25, 30,
 73,  1, 40, 79, 93, 24, 36, 6, 35, 99, 83, 54, 82, 45, 44, 33, 65, 78, 13, 80, 7, 26, 59, 42,
 29, 27]

In [7]:
t = SubsetRandomSampler(sampled_indices)
print(t.indices)

[38  3  8 81 89 91  5 39 32 74 49 68 15 56 21 28 46  4 55 96 77 75 25 30
 73  1 40 79 93 24 36  6 35 99 83 54 82 45 44 33 65 78 13 80  7 26 59 42
 29 27]


In [8]:
temp = DataLoader(TestSample, sampler=t, batch_size=5, shuffle=False)
print(type(temp))

<class 'torch.utils.data.dataloader.DataLoader'>


In [9]:
tr_it = iter(temp)
print(len(temp))
ind_list = []
for i in range(len(temp)):
    data = next(tr_it) 
    print(i, data)
    ind_list.extend(data.numpy().tolist())

10
0 tensor([96, 42, 80, 13, 93])
1 tensor([ 1, 78, 82, 65, 99])
2 tensor([ 7, 55, 27, 15, 32])
3 tensor([56, 29, 28, 24, 68])
4 tensor([59,  6, 83, 91, 39])
5 tensor([77, 38,  4, 81, 25])
6 tensor([44, 79, 30,  8, 26])
7 tensor([73, 36, 33,  3,  5])
8 tensor([74, 40, 45, 75, 49])
9 tensor([89, 46, 21, 54, 35])


In [10]:
len(ind_list)

50

In [11]:
set(ind_list) - set(a)

set()

In [12]:
ind_list.sort()
#ind_list
np.array(ind_list)
#[x for x in ind_list]

array([ 1,  3,  4,  5,  6,  7,  8, 13, 15, 21, 24, 25, 26, 27, 28, 29, 30,
       32, 33, 35, 36, 38, 39, 40, 42, 44, 45, 46, 49, 54, 55, 56, 59, 65,
       68, 73, 74, 75, 77, 78, 79, 80, 81, 82, 83, 89, 91, 93, 96, 99])

In [13]:
a.sort()
np.array(a)

array([ 1,  3,  4,  5,  6,  7,  8, 13, 15, 21, 24, 25, 26, 27, 28, 29, 30,
       32, 33, 35, 36, 38, 39, 40, 42, 44, 45, 46, 49, 54, 55, 56, 59, 65,
       68, 73, 74, 75, 77, 78, 79, 80, 81, 82, 83, 89, 91, 93, 96, 99])

## Reference

1. https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
    
2. https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler