https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

## Example 1: For map-style data

In [2]:
from torch.utils.data import Dataset, IterableDataset, DataLoader

In [3]:
class MyMapDataset(Dataset):
    """Map-style dataset backed by an in-memory, indexable container.

    Args:
        data: Object supporting ``len()`` and ``[]`` lookup; it is kept by
            reference as ``self.data``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        items = self.data
        return len(items)

    def __getitem__(self, idx):
        """Return the item stored under ``idx`` in the wrapped container."""
        items = self.data
        return items[idx]

In [6]:
# Twelve consecutive integers, served through the map-style dataset.
data = [*range(12)]

map_dataset = MyMapDataset(data)
# batch_size=4 makes the loader group the items four at a time.
loader = DataLoader(map_dataset, batch_size=4)

for batch in loader:
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8,  9, 10, 11])


## Example 2: For iterable data

In [7]:
class MyIterableDataset(IterableDataset):
    """Iterable-style dataset that streams items from a wrapped iterable.

    Args:
        data: Any iterable; it is kept by reference as ``self.data``.
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        """Return an iterator obtained from the wrapped iterable."""
        source = self.data
        return iter(source)

In [8]:
# Same twelve integers as before, this time streamed by the iterable dataset.
data = [*range(12)]

iterable_dataset = MyIterableDataset(data)
# The loader pulls items from the iterator and groups them four at a time.
loader = DataLoader(iterable_dataset, batch_size=4)

for batch in loader:
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8,  9, 10, 11])


In [11]:
from itertools import cycle, islice

class MyIterableDataset(IterableDataset):
    """Endless stream of whitespace-separated tokens read from a text file.

    Args:
        file_path: Path of the text file to tokenize.
    """

    def __init__(self, file_path):
        self.file_path = file_path

    def parse_file(self, file_path):
        """Yield the tokens of ``file_path`` one at a time, line by line.

        Lines are split on runs of whitespace, so blank lines, repeated
        spaces and trailing blanks never produce empty-string tokens.
        """
        with open(file_path, 'r') as file_obj:
            for line in file_obj:
                # split() with no argument also discards the newline.
                yield from line.split()

    def get_stream(self, file_path):
        """Yield the file's tokens forever, restarting from the top at EOF.

        The file is re-read on every pass instead of being wrapped in
        ``itertools.cycle``, which would keep a copy of every token in
        memory. A file without tokens gives an empty stream rather than
        looping forever.
        """
        while True:
            produced = False
            for token in self.parse_file(file_path):
                produced = True
                yield token
            if not produced:
                return

    def __iter__(self):
        """Return the endless token stream for ``self.file_path``."""
        return self.get_stream(self.file_path)

In [13]:
iterable_dataset = MyIterableDataset('file.txt')
loader = DataLoader(iterable_dataset, batch_size=5)

# The dataset keeps cycling through the file, so cap the demo at 8 batches.
for batch in islice(loader, 8):
    print(batch)

['Far', 'out', 'in', 'the', 'uncharted']
['backwaters', 'of', 'the', 'unfashionable', 'end']
['of', 'the', 'western', 'spiral', 'arm']
['of', 'the', 'galaxy', 'lies', 'a']
['small', 'unregarded', 'yellow', 'sun.', 'Far']
['out', 'in', 'the', 'uncharted', 'backwaters']
['of', 'the', 'unfashionable', 'end', 'of']
['the', 'western', 'spiral', 'arm', 'of']


In [17]:
from itertools import chain
import random


class MyIterableDataset(IterableDataset):
    """Yield batches assembled from ``batch_size`` parallel, repeating streams.

    Every stream walks its own shuffled ordering of ``data_list`` over and
    over; one item is taken from each stream to form a batch tuple.

    Args:
        data_list: Sequence of iterables holding the samples.
        batch_size: Number of parallel streams, i.e. the length of each
            yielded tuple.
    """

    def __init__(self, data_list, batch_size):
        self.batch_size = batch_size
        self.data_list = data_list

    @property
    def shuffled_data_list(self):
        """A new, randomly ordered list of the entries of ``data_list``."""
        entries = self.data_list
        return random.sample(entries, len(entries))

    def process_data(self, data):
        """Yield the items of one ``data_list`` entry in order."""
        yield from data

    def get_stream(self, data_list):
        """Chain the items of every entry, cycling over ``data_list``."""
        repeated = cycle(data_list)
        return chain.from_iterable(
            self.process_data(entry) for entry in repeated
        )

    def get_streams(self):
        """Zip together ``batch_size`` streams, each shuffled on its own."""
        streams = [
            self.get_stream(self.shuffled_data_list)
            for _ in range(self.batch_size)
        ]
        return zip(*streams)

    def __iter__(self):
        """Return the iterator of batch tuples."""
        return self.get_streams()

In [21]:
# Four sample sequences of different lengths (6, 3, 9 and 4 items).
data_list = [
    [12, 13, 14, 15, 16, 17],
    [27, 28, 29],
    [31, 32, 33, 34, 35, 36, 37, 38, 39],
    [40, 41, 42, 43],
]

In [22]:
iterable_dataset = MyIterableDataset(data_list, 4)
# batch_size=None turns off the DataLoader's automatic batching: the dataset
# itself already yields complete batches of 4.
loader = DataLoader(iterable_dataset, batch_size=None)

# The streams repeat without end, so show only the first 12 batches.
for batch in islice(loader, 12):
    print(batch)

[27, 27, 12, 40]
[28, 28, 13, 41]
[29, 29, 14, 42]
[12, 40, 15, 43]
[13, 41, 16, 31]
[14, 42, 17, 32]
[15, 43, 31, 33]
[16, 12, 32, 34]
[17, 13, 33, 35]
[31, 14, 34, 36]
[32, 15, 35, 37]
[33, 16, 36, 38]


In [32]:
from itertools import chain, islice
import random


class MyIterableDataset(IterableDataset):
    """Stream every item of ``data_list`` once, visiting entries in random order.

    Args:
        data_list: Sequence of iterables holding the samples.
    """

    def __init__(self, data_list):
        self.data_list = data_list

    @property
    def shuffled_data_list(self):
        """A new, randomly ordered list of the entries of ``data_list``."""
        entries = self.data_list
        return random.sample(entries, len(entries))

    def process_data(self, data):
        """Yield the items of one ``data_list`` entry in order."""
        yield from data

    def get_stream(self, data_list):
        """Chain the items of every entry of ``data_list`` into one stream."""
        return chain.from_iterable(
            self.process_data(entry) for entry in data_list
        )

    def __iter__(self):
        """Return a single pass over a freshly shuffled ``data_list``."""
        return self.get_stream(self.shuffled_data_list)

In [33]:
data_list

[[12, 13, 14, 15, 16, 17],
 [27, 28, 29],
 [31, 32, 33, 34, 35, 36, 37, 38, 39],
 [40, 41, 42, 43]]

In [34]:
# Wrap the sample sequences in the single-pass, shuffled iterable dataset.
iterable_dataset = MyIterableDataset(data_list)

In [35]:
# Let the DataLoader group the flat item stream into batches of 4.
loader = DataLoader(iterable_dataset, batch_size=4)

In [36]:
# One pass over the dataset, printing each batch with its index. The last
# batch is shorter when the item count is not a multiple of the batch size
# (see the 2-item batch in the output below).
for i, batch in enumerate(loader):
    print(i, '\t', batch)

0 	 tensor([27, 28, 29, 12])
1 	 tensor([13, 14, 15, 16])
2 	 tensor([17, 31, 32, 33])
3 	 tensor([34, 35, 36, 37])
4 	 tensor([38, 39, 40, 41])
5 	 tensor([42, 43])


### Iterable for MAD dataset

In [29]:
import pandas as pd
import numpy as np
import os, sys

  return f(*args, **kwds)
  return f(*args, **kwds)


In [30]:
# GLOBALS
# Directory layout for the MAD data; every path below derives from LOCAL_DIR.
# NOTE(review): LOCAL_DIR is hard-coded to one user's machine - adjust it
# before running anywhere else.
LOCAL_DIR = '/Users/varunn/Documents/'
PROJ_DIR = os.path.join(LOCAL_DIR, 'ExternalTest/MAD')
DATA_DIR = os.path.join(LOCAL_DIR, 'ExternalTest_Data/MAD')
# Sub-directories of DATA_DIR.
RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
RAW_DATA_TEST_DIR = os.path.join(DATA_DIR, 'raw_data_test')
INTERIM_DATA_DIR = os.path.join(DATA_DIR, 'interim')
MODEL_DIR = os.path.join(DATA_DIR, 'model')
PREDICTION_DIR = os.path.join(DATA_DIR, 'prediction')
# File-name template with two `{}` placeholders; it is not filled in anywhere
# in the visible code - presumably meant for str.format (TODO confirm).
RAW_INP_FN = os.path.join(RAW_DATA_DIR, '000{}_part_0{}.gz')

In [31]:
# Sample-file template; MyIterableDataset fills the `{}` with a 1-based file
# number via str.format.
inp_fn = os.path.join(RAW_DATA_TEST_DIR, 'sample_data_{}.gz')

In [48]:
from itertools import chain


class MyIterableDataset(IterableDataset):
    """Stream rows from a numbered series of gzip-compressed CSV files.

    Args:
        file_name: Path template that is formatted with the file number
            (``file_name.format(1)``, ``file_name.format(2)``, ...).
        num_files: How many files, numbered ``1..num_files``, to read.
    """

    def __init__(self, file_name, num_files):
        self.inp_fn = file_name
        self.num_files = num_files
        self.files = [file_name.format(number)
                      for number in range(1, num_files + 1)]

    def read_file(self, fn):
        """Load the gzip-compressed CSV at ``fn`` into a DataFrame."""
        return pd.read_csv(fn, compression='gzip')

    def process_data(self, fn):
        """Yield each row of ``fn`` as a tuple of its first seven columns."""
        frame = self.read_file(fn)
        for record in frame.itertuples():
            # Position 0 of an itertuples() record is the index; skip it.
            yield tuple(record[pos] for pos in range(1, 8))

    def get_stream(self, files):
        """Chain the rows of all ``files``, one file after another."""
        return chain.from_iterable(
            self.process_data(path) for path in files
        )

    def __iter__(self):
        """Return one pass over the rows of every file in ``self.files``."""
        return self.get_stream(self.files)

In [49]:
# Stream the first two sample files (sample_data_1.gz and sample_data_2.gz).
iterable_dataset = MyIterableDataset(inp_fn, 2)

In [50]:
# Let the DataLoader group the row tuples into batches of 2.
loader = DataLoader(iterable_dataset, batch_size=2)

In [53]:
# Print every batch with its index, followed by blank lines as a separator.
for i, batch in enumerate(loader):
    print(i, '\t', batch)
    print('\n\n')

0 	 [('4d5c46572f07de33c42db36aad17a40a', '055dd2d541b834f3df2ea5e2a6585592'), ('pageView', 'pageView'), ('719f0a3061ad449200417b52fde5b7a8', 'aea664dad45b5ed1676c8c0e52f2f38b'), tensor([1552374574, 1551400571]), ('f6bf60f9e64ea4af52dfb058e00e1f8b', '6fa8f87cc30fb02e7e4ddf73d072c1cc'), ('b1e8e536f770d6fc0f56b547a61fffa4', 'f0dcdc8116582cc2687a6c2f47241ecb'), tensor([ 949., 2395.], dtype=torch.float64)]



1 	 [('3f5e575613a308dc10f2d90115fcb5e7', '83b8ababec8f375ebfd4f1b4f334efb5'), ('pageView', 'pageView'), ('f56a37d9a8b045eb585db66f1b6ab83f', '15e75825b396262753bf9e1285e78720'), tensor([1555249274, 1554862662]), ('6242cac2b5ecad61341565481aeeff06', '0812f4a121588bf7a2d40c1c2c1d0cee'), ('1468ed5708e197374bfa55ce39734c06', '2316a9480261760c3e57cfbd89e5b758'), tensor([1599., 1699.], dtype=torch.float64)]



2 	 [('811234c7249af880bfd35a8619472abb', 'dcf96960a751352995042c94164ef3ea'), ('pageView', 'addToCart'), ('42499fd23524fa280f2c07f6e4d4e35d', '26f146680d5c93eabbff56b7a95e46b3'), te