In [5]:
import math
import numpy as np

class TextDataloader:
    def __init__(self, dataset, max_seq_len, batch_size, shuffle=True):
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size

        # shuffle logic vars
        self.shuffle = shuffle
        self.chunk_len = max_seq_len * batch_size

        # get seqence order
        num_seqs = (length - 1) // max_seq_len
        self.seq_order = np.array(range(num_seqs))
        if shuffle:
            np.random.shuffle(self.seq_order)

        # get source, target datasets, trim
        self.dataset = dataset
        self.source = self.shuffle_dataset(dataset[0: len(dataset) - 1])
        self.targets = self.shuffle_dataset(dataset[1: len(dataset)])
        self.dataset_len = len(self.source)
        # self.num_batches = math.ceil(num_seqs/self.batch_size) # include non-conforming batches
        self.num_batches = num_seqs // self.batch_size # trim off non-conforming batches

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index > self.num_batches - 1:
            raise StopIteration

        i = self.index
        chunk_pos = i * self.chunk_len
        data = self.source[chunk_pos: chunk_pos + self.chunk_len]
        target = self.targets[chunk_pos: chunk_pos + self.chunk_len]
 
        num_batches = min(self.batch_size, (self.dataset_len - chunk_pos) // self.max_seq_len)
        if num_batches == 0:
            raise StopIteration

        if(len(data) != len(target)):
            # remove mismatched batch sizes
            data = data.narrow(0, 0, self.max_seq_len * (num_batches - 1))
            target = target.narrow(0, 0, self.max_seq_len * (num_batches - 1))

        self.index += 1

        return self.batchify(data, target, num_batches)

    def batchify(self, data, target, num_batches):
        # Evenly divide the data across the batch_size batches.
        data = data.view(num_batches, -1).contiguous()
        target = target.view(num_batches, -1).contiguous()

        # shuffle data
        if self.shuffle:
            permutation = torch.randperm(data.size(0))
            data = data[permutation]
            target = target[permutation]

        # flatten targets
        target = target.reshape(-1)
        return data, target.reshape(-1)

    def shuffle_dataset(self, dataset):
        shuffled_dataset = map(lambda x: dataset[x * max_seq_len: (x + 1)* max_seq_len], self.seq_order)
        return torch.cat(list(shuffled_dataset))

In [6]:
# test TextDataloader
import torch

length = 149
batch_size = 5
max_seq_len = 10


dataset = torch.arange(0, length)
dataloader = TextDataloader(dataset, max_seq_len, batch_size, True)

for batch in dataloader:
    data, targets = batch
    print(data)
    print(targets)
    # break

tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
        [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
        [130, 131, 132, 133, 134, 135, 136, 137, 138, 139]])
tensor([101, 102, 103, 104, 105, 106, 107, 108, 109, 110,   1,   2,   3,   4,
          5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  16,  17,  18,
         19,  20,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 131, 132,
        133, 134, 135, 136, 137, 138, 139, 140])
tensor([[ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69],
        [ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
        [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
        [ 70,  71,  72,  73,  74,  75,  76,  77,  78,  79],
        [110, 111, 112, 113, 114, 115, 116, 117, 118, 119]])
tensor([ 61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  41,  42,  43,  44,
         45,  46,  47,  48,  49,  50,

In [1]:
import torch

length = 122
batch_size = 5
max_seq_len = 10

# num_seqs = math.ceil((length - 1) / max_seq_len)
num_seqs = (length - 1 ) // max_seq_len
# print(num_seqs)
seq_order = np.array(range(num_seqs))

np.random.shuffle(seq_order)
print(seq_order)

dataset = torch.arange(0, length)
shuffled_dataset = map(lambda x: dataset[x * max_seq_len: (x + 1)* max_seq_len], seq_order)
print(torch.cat(list(shuffled_dataset)))

def shuffle_dataset(dataset):
    shuffled_dataset = map(lambda x: dataset[x * max_seq_len: (x + 1)* max_seq_len], seq_order)
    return torch.cat(list(shuffled_dataset))

data = dataset[0: len(dataset) - 1]
targets = dataset[1: len(dataset)]
print(shuffle_dataset(data))
print(shuffle_dataset(targets))






NameError: name 'np' is not defined