In [43]:
import math
import numpy as np

class TextDataloader:
    previous_source = torch.tensor([])
    previous_target = torch.tensor([])

    def __init__(self, dataset, max_seq_len, batch_size, shuffle=True, include_previous=True):
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size

        # shuffle logic vars
        self.shuffle = shuffle
        self.chunk_len = max_seq_len * batch_size

        # get seqence order
        num_seqs = (len(dataset) - 1) // max_seq_len
        self.seq_order = np.array(range(num_seqs))
        if shuffle:
            np.random.shuffle(self.seq_order)

        # get source, target datasets, trim
        self.dataset = dataset
        self.source = self.shuffle_dataset(dataset[0: len(dataset) - 1])
        self.targets = self.shuffle_dataset(dataset[1: len(dataset)])
        self.num_batches = num_seqs // self.batch_size # trim off non-conforming batches

        if include_previous: # leftover data from previous run is included
            self.source = torch.cat([TextDataloader.previous_source, self.source])
            self.target = torch.cat([TextDataloader.previous_target, self.targets])
            self.num_batches = math.ceil(num_seqs/self.batch_size) # include non-conforming batches
        
        self.dataset_len = len(self.source)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index > self.num_batches - 1:
            raise StopIteration

        i = self.index
        chunk_pos = i * self.chunk_len
        data = self.source[chunk_pos: chunk_pos + self.chunk_len]
        target = self.targets[chunk_pos: chunk_pos + self.chunk_len]
 
        num_batches = min(self.batch_size, (self.dataset_len - chunk_pos) // self.max_seq_len)
        if num_batches < self.batch_size:
            TextDataloader.previous_source = data
            TextDataloader.previous_target = target
            
        # if num_batches == 0:
            raise StopIteration

        self.index += 1
        print(self.batch_size)
        return self.batchify(data, target, num_batches)

    def batchify(self, data, target, num_batches):
        # Evenly divide the data across the batch_size batches.
        data = data.view(num_batches, -1).contiguous()
        target = target.view(num_batches, -1).contiguous()

        # shuffle data
        if self.shuffle:
            permutation = torch.randperm(data.size(0))
            data = data[permutation]
            target = target[permutation]

        # flatten targets
        target = target.reshape(-1)
        return data, target.reshape(-1)

    def shuffle_dataset(self, dataset):
        shuffled_dataset = map(lambda x: dataset[x * self.max_seq_len: (x + 1)* self.max_seq_len], self.seq_order)
        return torch.cat(list(shuffled_dataset))

    def reset_previous():
        TextDataloader.previous_source = torch.tensor([])
        TextDataloader.previous_target = torch.tensor([])

In [51]:
TextDataloader.reset_previous()

In [61]:
# test TextDataloader
import torch

length = 111
batch_size = 5
max_seq_len = 10


dataset = torch.arange(0, length)
dataloader = TextDataloader(dataset, max_seq_len, batch_size, False)

for batch in dataloader:
    data, targets = batch
    print(data)
    print(targets)
    # break

5
tensor([[100., 101., 102., 103., 104., 105., 106., 107., 108., 109.],
        [  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],
        [ 10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.],
        [ 30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.]])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
5
tensor([[40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
        [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
        [60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],
        [70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],
        [80., 81., 82., 83., 84., 85., 86., 87., 88., 89.]])
tensor([ 51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71

In [60]:
import torch

length = 122
batch_size = 5
max_seq_len = 10

# num_seqs = math.ceil((length - 1) / max_seq_len)
num_seqs = (length - 1 ) // max_seq_len
# print(num_seqs)
seq_order = np.array(range(num_seqs))

np.random.shuffle(seq_order)
print(seq_order)

dataset = torch.arange(0, length)
shuffled_dataset = map(lambda x: dataset[x * max_seq_len: (x + 1)* max_seq_len], seq_order)
print(torch.cat(list(shuffled_dataset)))

def shuffle_dataset(dataset):
    shuffled_dataset = map(lambda x: dataset[x * max_seq_len: (x + 1)* max_seq_len], seq_order)
    return torch.cat(list(shuffled_dataset))

data = dataset[0: len(dataset) - 1]
targets = dataset[1: len(dataset)]
print(shuffle_dataset(data))
print(shuffle_dataset(targets))






[ 5  2 11  6  3  1  4  9 10  8  0  7]
tensor([ 50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  20,  21,  22,  23,
         24,  25,  26,  27,  28,  29, 110, 111, 112, 113, 114, 115, 116, 117,
        118, 119,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  30,  31,
         32,  33,  34,  35,  36,  37,  38,  39,  10,  11,  12,  13,  14,  15,
         16,  17,  18,  19,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,
         90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109,  80,  81,  82,  83,  84,  85,  86,  87,
         88,  89,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  70,  71,
         72,  73,  74,  75,  76,  77,  78,  79])
tensor([ 50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  20,  21,  22,  23,
         24,  25,  26,  27,  28,  29, 110, 111, 112, 113, 114, 115, 116, 117,
        118, 119,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  30,  31,
         32,  33,  34,  35,  36,  37,  38,  39,  10,  1

In [31]:
torch.cat([torch.tensor([], torch.tensor([0])])


tensor([0.])