# PyTorch Dataset Class and Custom Dataset Objects

In [1]:
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import os

Let’s first mock a simple dataset by creating a `Dataset` of all numbers from 1 to 1000. We'll aptly name this the `NumbersDataset`.

In [2]:
class NumbersDataset(Dataset):
    """Map-style dataset holding the integers 1 through 1000."""

    def __init__(self):
        # Materialise the range as a list so indexing returns ints and
        # slicing returns plain lists.
        first, last = 1, 1000
        self.samples = [*range(first, last + 1)]

    def __len__(self):
        """Return how many samples are stored."""
        count = len(self.samples)
        return count

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx``; ``idx`` may be an int or a slice."""
        samples = self.samples
        return samples[idx]


if __name__ == '__main__':
    dataset = NumbersDataset()
    # Show the dataset size, one indexed sample and one slice, each on its own line.
    print(len(dataset), dataset[100], dataset[122:361], sep="\n")

1000
101
[123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 

Extend the dataset so it can store all whole numbers in an interval from `low` (inclusive) up to `high` (exclusive).

In [3]:
class NumbersDataset(Dataset):
    """Map-style dataset of the consecutive integers ``low, low + 1, ..., high - 1``."""

    def __init__(self, low, high):
        # ``high`` itself is excluded, matching ``range`` semantics.
        self.samples = [*range(low, high)]

    def __len__(self):
        """Return how many samples are stored."""
        count = len(self.samples)
        return count

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx``; ``idx`` may be an int or a slice."""
        samples = self.samples
        return samples[idx]


if __name__ == '__main__':
    dataset = NumbersDataset(2821, 8295)
    # Show the dataset size, one indexed sample and one slice, each on its own line.
    print(len(dataset), dataset[100], dataset[122:361], sep="\n")

5474
2921
[2943, 2944, 2945, 2946, 2947, 2948, 2949, 2950, 2951, 2952, 2953, 2954, 2955, 2956, 2957, 2958, 2959, 2960, 2961, 2962, 2963, 2964, 2965, 2966, 2967, 2968, 2969, 2970, 2971, 2972, 2973, 2974, 2975, 2976, 2977, 2978, 2979, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987, 2988, 2989, 2990, 2991, 2992, 2993, 2994, 2995, 2996, 2997, 2998, 2999, 3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009, 3010, 3011, 3012, 3013, 3014, 3015, 3016, 3017, 3018, 3019, 3020, 3021, 3022, 3023, 3024, 3025, 3026, 3027, 3028, 3029, 3030, 3031, 3032, 3033, 3034, 3035, 3036, 3037, 3038, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046, 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3055, 3056, 3057, 3058, 3059, 3060, 3061, 3062, 3063, 3064, 3065, 3066, 3067, 3068, 3069, 3070, 3071, 3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079, 3080, 3081, 3082, 3083, 3084, 3085, 3086, 3087, 3088, 3089, 3090, 3091, 3092, 3093, 3094, 3095, 3096, 3097, 3098, 3099, 3100, 3101, 3102, 3103, 3104, 3105, 3106, 3107,

## Test IterableDatasets

In [5]:
from itertools import cycle, islice, chain

# Ten consecutive integers used as toy input for the iterable datasets.
data = list(range(10))

### Basic

In [6]:
class MyIterableDataset(IterableDataset):
    """Iterable-style dataset that walks over a stored iterable once per pass."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        """Return a fresh iterator over the stored data."""
        source = self.data
        return iter(source)
    
# Wrap the toy list in the iterable dataset.
iterable_dataset = MyIterableDataset(data)

# Let the DataLoader group consecutive items into batches of four.
loader = DataLoader(iterable_dataset, batch_size=4)

# One full pass over the loader, printing every batch.
for batch in loader:
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([8, 9])


### Iterate through stream of data



In [7]:
class MyIterableDataset(IterableDataset):
    """Iterable-style dataset that repeats its data as an endless stream."""

    def __init__(self, data):
        self.data = data

    def process_data(self, data):
        """Yield the items of ``data`` one at a time."""
        yield from data

    def get_stream(self, data):
        """Return an iterator that loops over the processed items indefinitely."""
        processed = self.process_data(data)
        return cycle(processed)

    def __iter__(self):
        """Return the endless stream built from the stored data."""
        source = self.data
        return self.get_stream(source)
    
iterable_dataset = MyIterableDataset(data)

loader = DataLoader(iterable_dataset, batch_size=4)

# Cap the loop at 8 batches with islice.
for batch in islice(loader, 8):
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([8, 9, 0, 1])
tensor([2, 3, 4, 5])
tensor([6, 7, 8, 9])
tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([8, 9, 0, 1])


### Stream data from multiple files

In [8]:
class MyIterableDataset(IterableDataset):
    """Iterable-style dataset that streams several data sources back to back, forever."""

    def __init__(self, data_list):
        self.data_list = data_list

    def process_data(self, data):
        """Yield the items of one data source, one at a time."""
        yield from data

    def get_stream(self, data_list):
        """Return one flat iterator that visits the sources in order and then starts over."""
        per_source = (self.process_data(source) for source in cycle(data_list))
        return chain.from_iterable(per_source)

    def __iter__(self):
        """Return the endless flattened stream over the stored sources."""
        sources = self.data_list
        return self.get_stream(sources)
    
# Toy input: several lists of different lengths (presumably each one stands in
# for the contents of one file).
data_list = [
    [12, 13, 14, 15, 16, 17],
    [27, 28, 29],
    [31, 32, 33, 34, 35, 36, 37, 38, 39],
    [40, 41, 42, 43],
]

iterable_dataset = MyIterableDataset(data_list)

loader = DataLoader(iterable_dataset, batch_size=4)

# Cap the loop at 8 batches with islice.
for batch in islice(loader, 8):
    print(batch)

tensor([12, 13, 14, 15])
tensor([16, 17, 27, 28])
tensor([29, 31, 32, 33])
tensor([34, 35, 36, 37])
tensor([38, 39, 40, 41])
tensor([42, 43, 12, 13])
tensor([14, 15, 16, 17])
tensor([27, 28, 29, 31])


### Return batches of sequences

In [9]:
class MyIterableDataset(IterableDataset):
    """Iterable-style dataset that yields tuples drawn from ``batch_size`` parallel streams."""

    def __init__(self, data_list, batch_size):
        self.data_list = data_list
        self.batch_size = batch_size

    def process_data(self, data):
        """Yield the items of one data source, one at a time."""
        yield from data

    def get_stream(self, data_list):
        """Return one flat iterator that visits the sources in order and then starts over."""
        per_source = (self.process_data(source) for source in cycle(data_list))
        return chain.from_iterable(per_source)

    def get_streams(self):
        """Zip ``batch_size`` independent streams so each step yields one item per stream."""
        streams = []
        for _ in range(self.batch_size):
            # Every stream is built from the same source list in the same order.
            streams.append(self.get_stream(self.data_list))
        return zip(*streams)

    def __iter__(self):
        """Return the iterator of per-step tuples."""
        return self.get_streams()
    
# The dataset itself assembles groups of 4 items (one per stream).
iterable_dataset = MyIterableDataset(data_list, batch_size=4)

# batch_size=None: the loader is not asked to batch; each dataset item is passed through as a batch.
loader = DataLoader(iterable_dataset, batch_size=None)

# Cap the loop at 8 batches with islice.
for batch in islice(loader, 8):
    print(batch)

[12, 12, 12, 12]
[13, 13, 13, 13]
[14, 14, 14, 14]
[15, 15, 15, 15]
[16, 16, 16, 16]
[17, 17, 17, 17]
[27, 27, 27, 27]
[28, 28, 28, 28]


### Partition data into groups, feed each group into a single stream

In [10]:
import random

class MyIterableDataset(IterableDataset):
    """Iterable-style dataset whose parallel streams each visit the sources in a random order."""

    def __init__(self, data_list, batch_size):
        self.data_list = data_list
        self.batch_size = batch_size

    @property
    def shuffled_data_list(self):
        """Return a new, randomly ordered list of the sources; ``data_list`` is left untouched."""
        sources = self.data_list
        return random.sample(sources, len(sources))

    def process_data(self, data):
        """Yield the items of one data source, one at a time."""
        yield from data

    def get_stream(self, data_list):
        """Return one flat iterator that visits the given sources in order and then starts over."""
        per_source = (self.process_data(source) for source in cycle(data_list))
        return chain.from_iterable(per_source)

    def get_streams(self):
        """Zip ``batch_size`` streams, each built from its own shuffled source order."""
        streams = []
        for _ in range(self.batch_size):
            # A fresh shuffle per stream, drawn in stream order.
            streams.append(self.get_stream(self.shuffled_data_list))
        return zip(*streams)

    def __iter__(self):
        """Return the iterator of per-step tuples."""
        return self.get_streams()
    
# The dataset itself assembles groups of 4 items (one per stream).
iterable_dataset = MyIterableDataset(data_list, batch_size=4)

# batch_size=None: the loader is not asked to batch; each dataset item is passed through as a batch.
loader = DataLoader(iterable_dataset, batch_size=None)

# Cap the loop at 12 batches with islice.
for batch in islice(loader, 12):
    print(batch)

[12, 27, 31, 12]
[13, 28, 32, 13]
[14, 29, 33, 14]
[15, 40, 34, 15]
[16, 41, 35, 16]
[17, 42, 36, 17]
[27, 43, 37, 40]
[28, 12, 38, 41]
[29, 13, 39, 42]
[40, 14, 40, 43]
[41, 15, 41, 27]
[42, 16, 42, 28]


### Parallel distribution

In [11]:
# def worker_init_fn(_):
#     worker_info = torch.utils.data.get_worker_info()
    
#     dataset = worker_info.dataset
#     worker_id = worker_info.id
#     split_size = len(dataset.data) // worker_info.num_workers
    
#     dataset.data = dataset.data[worker_id * split_size:(worker_id + 1) * split_size]

import time

class MyIterableDataset(IterableDataset):
    """Iterable-style dataset yielding tuples drawn from ``batch_size`` parallel streams.

    Each stream cycles endlessly over ``data_list`` in its own random order and
    tags every item with an identifier so the origin of an item can be
    inspected in the output.
    """

    def __init__(self, data_list, batch_size):
        self.data_list = data_list
        self.batch_size = batch_size

    @property
    def shuffled_data_list(self):
        """Return a new, randomly ordered list of the sources; ``data_list`` is left untouched."""
        return random.sample(self.data_list, len(self.data_list))

    def process_data(self, data):
        """Yield ``(item, tag)`` pairs for every item in ``data``.

        ``tag`` is ``id(self)`` when ``get_worker_info()`` reports a worker
        context and ``-1`` otherwise. Each item is delayed by 0.1 s.
        """
        for x in data:
            worker = torch.utils.data.get_worker_info()
            # The tag is the dataset object's id, not the worker's index.
            worker_id = id(self) if worker is not None else -1

            # Artificial delay (presumably standing in for per-item processing cost).
            time.sleep(0.1)

            yield x, worker_id

    def get_stream(self, data_list):
        """Return one flat iterator that visits the given sources in order and then starts over."""
        return chain.from_iterable(map(self.process_data, cycle(data_list)))

    def get_streams(self):
        """Zip ``batch_size`` streams, each built from its own shuffled source order."""
        return zip(*[self.get_stream(self.shuffled_data_list) for _ in range(self.batch_size)])

    def __iter__(self):
        """Return the iterator of per-step tuples."""
        return self.get_streams()

    @classmethod
    def split_datasets(cls, data_list, batch_size, max_workers):
        """Split one logical batch across several equally sized datasets.

        Picks the largest worker count ``n <= max_workers`` that divides
        ``batch_size`` evenly and returns ``n`` datasets, each producing
        ``batch_size // n`` streams.

        Raises:
            ValueError: if ``max_workers`` is smaller than 1.
        """
        if max_workers < 1:
            raise ValueError(f"max_workers must be at least 1, got {max_workers}")

        # n == 1 always divides batch_size, so a match is guaranteed.
        num_workers = next(
            n for n in range(max_workers, 0, -1) if batch_size % n == 0
        )
        split_size = batch_size // num_workers

        return [cls(data_list, batch_size=split_size) for _ in range(num_workers)]
    
class MultiStreamDataLoader:
    """Merge the output of several datasets, each served by its own single-worker DataLoader."""

    def __init__(self, datasets):
        self.datasets = datasets

    def get_stream_loaders(self):
        """Return an iterator yielding one tuple of batch parts (one per dataset) per step."""
        # Iterate over this instance's datasets rather than a module-level
        # `datasets` name, so the loader uses what it was constructed with.
        loaders = [
            DataLoader(dataset, num_workers=1, batch_size=None)
            for dataset in self.datasets
        ]
        return zip(*loaders)

    def __iter__(self):
        """Yield full batches built by concatenating the parts from every loader."""
        for batch_parts in self.get_stream_loaders():
            yield list(chain.from_iterable(batch_parts))
                       

In [12]:
# max_workers=1: the whole batch of 4 streams comes from a single dataset.
datasets = MyIterableDataset.split_datasets(data_list, batch_size=4, max_workers=1)

loader = MultiStreamDataLoader(datasets) 

# Time how long it takes to pull 12 batches.
start = time.time()
for batch in islice(loader, 12):
    print(batch)
end = time.time()
print("time:", end-start)

[[12, 140394551490640], [31, 140394551490640], [27, 140394551490640], [40, 140394551490640]]
[[13, 140394551490640], [32, 140394551490640], [28, 140394551490640], [41, 140394551490640]]
[[14, 140394551490640], [33, 140394551490640], [29, 140394551490640], [42, 140394551490640]]
[[15, 140394551490640], [34, 140394551490640], [12, 140394551490640], [43, 140394551490640]]
[[16, 140394551490640], [35, 140394551490640], [13, 140394551490640], [12, 140394551490640]]
[[17, 140394551490640], [36, 140394551490640], [14, 140394551490640], [13, 140394551490640]]
[[27, 140394551490640], [37, 140394551490640], [15, 140394551490640], [14, 140394551490640]]
[[28, 140394551490640], [38, 140394551490640], [16, 140394551490640], [15, 140394551490640]]
[[29, 140394551490640], [39, 140394551490640], [17, 140394551490640], [16, 140394551490640]]
[[31, 140394551490640], [12, 140394551490640], [40, 140394551490640], [17, 140394551490640]]
[[32, 140394551490640], [13, 140394551490640], [41, 140394551490640], 

Increase workers

In [13]:
# max_workers=2: the batch of 4 is split into 2 datasets with 2 streams each.
datasets = MyIterableDataset.split_datasets(data_list, batch_size=4, max_workers=2)

loader = MultiStreamDataLoader(datasets) 

# Time how long it takes to pull 12 batches.
start = time.time()
for batch in islice(loader, 12):
    print(batch)
end = time.time()
print("time:", end-start)

[[27, 140394551489296], [31, 140394551489296], [12, 140394551488720], [12, 140394551488720]]
[[28, 140394551489296], [32, 140394551489296], [13, 140394551488720], [13, 140394551488720]]
[[29, 140394551489296], [33, 140394551489296], [14, 140394551488720], [14, 140394551488720]]
[[12, 140394551489296], [34, 140394551489296], [15, 140394551488720], [15, 140394551488720]]
[[13, 140394551489296], [35, 140394551489296], [16, 140394551488720], [16, 140394551488720]]
[[14, 140394551489296], [36, 140394551489296], [17, 140394551488720], [17, 140394551488720]]
[[15, 140394551489296], [37, 140394551489296], [40, 140394551488720], [27, 140394551488720]]
[[16, 140394551489296], [38, 140394551489296], [41, 140394551488720], [28, 140394551488720]]
[[17, 140394551489296], [39, 140394551489296], [42, 140394551488720], [29, 140394551488720]]
[[31, 140394551489296], [12, 140394551489296], [43, 140394551488720], [31, 140394551488720]]
[[32, 140394551489296], [13, 140394551489296], [27, 140394551488720], 

In [14]:
# 128 is not divisible by 5, so the split falls back to the largest divisor
# not exceeding max_workers (4), giving 4 datasets.
datasets = MyIterableDataset.split_datasets(data_list, batch_size=128, max_workers=5)
datasets

[<__main__.MyIterableDataset at 0x7fb026f63e90>,
 <__main__.MyIterableDataset at 0x7fb026f63910>,
 <__main__.MyIterableDataset at 0x7fb026f63890>,
 <__main__.MyIterableDataset at 0x7fb026f63bd0>]

Increase workers and batch size

In [15]:
# max_workers=4: the batch of 4 is split into 4 datasets with 1 stream each.
datasets = MyIterableDataset.split_datasets(data_list, batch_size=4, max_workers=4)

loader = MultiStreamDataLoader(datasets) 

# Time how long it takes to pull 12 batches.
start = time.time()
for batch in islice(loader, 12):
    print(batch)
end = time.time()
print("time:", end-start)

[[31, 140394544730384], [40, 140394544730448], [40, 140394544730512], [31, 140394544730576]]
[[32, 140394544730384], [41, 140394544730448], [41, 140394544730512], [32, 140394544730576]]
[[33, 140394544730384], [42, 140394544730448], [42, 140394544730512], [33, 140394544730576]]
[[34, 140394544730384], [43, 140394544730448], [43, 140394544730512], [34, 140394544730576]]
[[35, 140394544730384], [31, 140394544730448], [12, 140394544730512], [35, 140394544730576]]
[[36, 140394544730384], [32, 140394544730448], [13, 140394544730512], [36, 140394544730576]]
[[37, 140394544730384], [33, 140394544730448], [14, 140394544730512], [37, 140394544730576]]
[[38, 140394544730384], [34, 140394544730448], [15, 140394544730512], [38, 140394544730576]]
[[39, 140394544730384], [35, 140394544730448], [16, 140394544730512], [39, 140394544730576]]
[[12, 140394544730384], [36, 140394544730448], [17, 140394544730512], [27, 140394544730576]]
[[13, 140394544730384], [37, 140394544730448], [27, 140394544730512], 

In [16]:
# max_workers=6 exceeds what a batch of 4 can use: 6 and 5 do not divide 4,
# so the split settles on 4 datasets with 1 stream each.
datasets = MyIterableDataset.split_datasets(data_list, batch_size=4, max_workers=6)
# worker_dataset = datasets[worker_count]
loader = MultiStreamDataLoader(datasets) 

# Time how long it takes to pull 12 batches.
start = time.time()
for batch in islice(loader, 12):
    print(batch)
end = time.time()
print("time:", end-start)

[[40, 140394544640656], [40, 140394544640336], [27, 140394544640592], [27, 140394544642640]]
[[41, 140394544640656], [41, 140394544640336], [28, 140394544640592], [28, 140394544642640]]
[[42, 140394544640656], [42, 140394544640336], [29, 140394544640592], [29, 140394544642640]]
[[43, 140394544640656], [43, 140394544640336], [31, 140394544640592], [12, 140394544642640]]
[[31, 140394544640656], [31, 140394544640336], [32, 140394544640592], [13, 140394544642640]]
[[32, 140394544640656], [32, 140394544640336], [33, 140394544640592], [14, 140394544642640]]
[[33, 140394544640656], [33, 140394544640336], [34, 140394544640592], [15, 140394544642640]]
[[34, 140394544640656], [34, 140394544640336], [35, 140394544640592], [16, 140394544642640]]
[[35, 140394544640656], [35, 140394544640336], [36, 140394544640592], [17, 140394544642640]]
[[36, 140394544640656], [36, 140394544640336], [37, 140394544640592], [40, 140394544642640]]
[[37, 140394544640656], [37, 140394544640336], [38, 140394544640592], 