# Word Batch

In [1]:
import re

In [2]:
class DataLoader:
    """Word-level sliding-window dataset for next-token prediction.

    The file is read as UTF-8, lower-cased and split into word and
    punctuation tokens.  Each distinct token gets an integer id (ids follow
    the sorted token order, with ``'<|eos|>'`` and ``'<|unk|>'`` appended at
    the end).  The encoded text is then cut into windows: ``features[k]`` is
    a run of ``batch_size`` token ids and ``labels[k]`` is the same run
    shifted one position to the right.

    After construction the object holds the training windows; call
    :meth:`eval` to switch to the held-out tail and :meth:`train` to switch
    back.  ``len(obj)`` and ``obj[i]`` always refer to the currently
    selected set of windows.

    Args:
        filename: Path of the text file to load.
        batch_size: Number of tokens in every window (despite the name, it
            is used as the window length, not as a number of samples).
        stride: Distance, in tokens, between the starts of two consecutive
            windows.
    """

    # Splits on whitespace and on the listed punctuation marks; the capture
    # group makes re.split keep the punctuation as separate tokens.
    _SPLIT_PATTERN = re.compile(r'([,.:;?_!"()\']|\s)')
    # Whitespace that " ".join() inserted in front of a punctuation token.
    _SPACE_BEFORE_PUNCT = re.compile(r'\s+([,.:;?_!"()\'])')

    def __init__(self, filename, batch_size, stride):
        self.filename = filename
        self.batch_size = batch_size
        self.stride = stride

        with open(self.filename, 'r', encoding='utf-8') as f:
            self.text = f.read().lower()

        self.vocabulary = sorted(set(self.split_text(self.text)))
        self.vocabulary.extend(['<|eos|>', '<|unk|>'])
        self.word2index = {word: index for index, word in enumerate(self.vocabulary)}
        self.index2word = {index: word for index, word in enumerate(self.vocabulary)}
        self.tokens = self.encode(self.text)

        # Start out in training mode.
        self.train()

    @staticmethod
    def split_text(text):
        """Return the lower-cased word and punctuation tokens of ``text``.

        Whitespace-only and empty fragments are dropped.
        """
        fragments = DataLoader._SPLIT_PATTERN.split(text.lower())
        return [t.strip() for t in fragments if t.strip()]

    def _split_index(self):
        """Return the token index where the first 90 % of the tokens ends."""
        return len(self.tokens) * 9 // 10

    def _build_windows(self, start, stop):
        """Fill ``features``/``labels`` with windows starting in [start, stop).

        ``start`` is clamped to 0: a negative start would be interpreted by
        slicing as an offset from the end of ``tokens`` and would yield
        truncated, misplaced windows when ``batch_size`` exceeds the amount
        of available text.
        """
        size = self.batch_size
        starts = range(max(start, 0), stop, self.stride)
        self.features = [self.tokens[i: i + size] for i in starts]
        self.labels = [self.tokens[i + 1: i + size + 1] for i in starts]

    def train(self):
        """Select the training windows.

        Window starts run from 0 up to (but excluding)
        ``split - batch_size``, so every label stays inside the first 90 %
        of the tokens.
        """
        self._build_windows(0, self._split_index() - self.batch_size)

    def eval(self):
        """Select the evaluation windows taken from the tail of the tokens.

        Window starts run from ``split - batch_size + 1`` up to (but
        excluding) ``len(tokens) - batch_size``, so the shifted label window
        never runs past the end of the token list.  When the text is too
        short for a single full window the result is empty.
        """
        self._build_windows(self._split_index() - self.batch_size + 1,
                            len(self.tokens) - self.batch_size)

    def __len__(self):
        """Return the number of windows in the currently selected set."""
        return len(self.features)

    def __getitem__(self, index):
        """Return the ``(features, labels)`` pair stored at ``index``."""
        return self.features[index], self.labels[index]

    def encode(self, text):
        """Convert ``text`` to a list of token ids.

        The text is tokenised with :meth:`split_text` (which lower-cases
        it); tokens missing from the vocabulary map to the id of
        ``'<|unk|>'``.
        """
        lookup = self.word2index
        return [lookup[word] if word in lookup else lookup['<|unk|>']
                for word in self.split_text(text)]

    def decode(self, tokens):
        """Convert token ids back to text.

        Tokens are joined with single spaces, then the whitespace in front
        of each punctuation mark is removed.
        """
        text = " ".join([self.index2word[index] for index in tokens])
        return self._SPACE_BEFORE_PUNCT.sub(r'\1', text)

In [3]:
# Build the training windows: 4 tokens per window, advancing 1 token at a time.
dataset = DataLoader('../one-day.txt', 4, 1)

# Show every (features, labels) pair; labels are the features shifted by one.
for pair in zip(dataset.features, dataset.labels):
    print("(Feature, Label): ", pair)

(Feature, Label):  ([149, 59, 113, 226], [59, 113, 226, 125])
(Feature, Label):  ([59, 113, 226, 125], [113, 226, 125, 145])
(Feature, Label):  ([113, 226, 125, 145], [226, 125, 145, 5])
(Feature, Label):  ([226, 125, 145, 5], [125, 145, 5, 264])
(Feature, Label):  ([125, 145, 5, 264], [145, 5, 264, 70])
(Feature, Label):  ([145, 5, 264, 70], [5, 264, 70, 137])
(Feature, Label):  ([5, 264, 70, 137], [264, 70, 137, 126])
(Feature, Label):  ([264, 70, 137, 126], [70, 137, 126, 226])
(Feature, Label):  ([70, 137, 126, 226], [137, 126, 226, 217])
(Feature, Label):  ([137, 126, 226, 217], [126, 226, 217, 178])
(Feature, Label):  ([126, 226, 217, 178], [226, 217, 178, 155])
(Feature, Label):  ([226, 217, 178, 155], [217, 178, 155, 226])
(Feature, Label):  ([217, 178, 155, 226], [178, 155, 226, 106])
(Feature, Label):  ([178, 155, 226, 106], [155, 226, 106, 3])
(Feature, Label):  ([155, 226, 106, 3], [226, 106, 3, 33])
(Feature, Label):  ([226, 106, 3, 33], [106, 3, 33, 199])
(Feature, Label)