# 词批次

In [46]:
import re
from abc import abstractmethod, ABC
import numpy as np

## Foundation

### Tensor

In [47]:
class Tensor:
    """Minimal reverse-mode autograd value wrapping a NumPy array.

    Every operation returns a new ``Tensor`` that remembers the operands it
    was built from (``parents``) and a closure (``gradient_fn``) that adds
    this tensor's ``grad`` contribution into those operands' ``grad``.
    """

    def __init__(self, data):
        """Wrap *data* (anything ``np.array`` accepts) as a graph leaf."""
        self.data = np.array(data)
        # Starts as the scalar 0 so the first ``+=`` works for any shape.
        self.grad = 0
        self.gradient_fn = lambda: None
        self.parents = set()

    def backward(self):
        """Propagate ``self.grad`` to every tensor this one was built from.

        The root gradient is not seeded here: assign ``self.grad`` before
        calling. Each reachable node runs its ``gradient_fn`` exactly once,
        and only after every tensor that consumes it has already run, so a
        node shared by several consumers forwards its complete gradient a
        single time instead of forwarding partial sums repeatedly.
        """
        # Iterative depth-first post-order: a node is appended only after all
        # of its parents, so the reversed list visits consumers first.
        order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for parent in node.parents:
                if parent not in visited:
                    stack.append((parent, False))

        for node in reversed(order):
            # ``gradient_fn`` may have been cleared by the caller.
            if node.gradient_fn:
                node.gradient_fn()

    def shape(self):
        """Return the shape tuple of the wrapped array."""
        return self.data.shape

    def size(self):
        """Return the product of all dimensions except the first."""
        return np.prod(self.data.shape[1:])

    def __str__(self):
        return str(self.data)

    def __add__(self, other):
        """Element-wise sum of two tensors.

        NOTE(review): gradients are passed through unchanged, so operands
        are presumably expected to have matching shapes (no reduction is
        done for broadcast operands).
        """
        p = Tensor(self.data + other.data)

        def gradient_fn():
            self.grad += p.grad
            other.grad += p.grad

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

    def __mul__(self, other):
        """Element-wise product of two tensors."""
        p = Tensor(self.data * other.data)

        def gradient_fn():
            # Product rule: each operand receives the other operand's value.
            self.grad += p.grad * other.data
            other.grad += p.grad * self.data

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

    def concat(self, other, axis):
        """Concatenate *other* after this tensor along *axis*."""
        p = Tensor(np.concatenate([self.data, other.data], axis=axis))

        def gradient_fn():
            # Split the incoming gradient back at the seam between operands.
            grad = np.split(p.grad, [self.data.shape[axis]], axis=axis)
            self.grad += grad[0]
            other.grad += grad[1]

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

### Base Dataset

In [48]:
class Dataset(ABC):
    """Abstract base for datasets that serve fixed-size batches.

    Subclasses implement ``load`` to populate ``train_features``,
    ``train_labels``, ``test_features`` and ``test_labels``, and
    ``estimate`` to score predictions. A new instance loads its data
    and then selects the training split.
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size
        self.load()
        self.train()

    @abstractmethod
    def load(self):
        """Read the raw data and fill the train/test attributes."""

    def train(self):
        """Make the training split the active features/labels."""
        self.features = self.train_features
        self.labels = self.train_labels

    def eval(self):
        """Make the test split the active features/labels."""
        self.features = self.test_features
        self.labels = self.test_labels

    def shape(self):
        """Return ``(feature_size, label_size)`` of the active split."""
        feature_size = Tensor(self.features).size()
        label_size = Tensor(self.labels).size()
        return feature_size, label_size

    def items(self):
        """Return the whole active split as a ``(features, labels)`` pair."""
        return Tensor(self.features), Tensor(self.labels)

    def __len__(self):
        """Number of complete batches in the active split."""
        complete_batches, _ = divmod(len(self.features), self.batch_size)
        return complete_batches

    def __getitem__(self, index):
        """Return batch number *index* as a ``(features, labels)`` pair."""
        window = slice(index * self.batch_size, (index + 1) * self.batch_size)
        return Tensor(self.features[window]), Tensor(self.labels[window])

    @abstractmethod
    def estimate(self, predictions):
        """Evaluate *predictions*; the metric is defined by the subclass."""

## Data

### LLM Dataset

In [49]:
class LLMDataset(Dataset):
    """Next-token prediction dataset built from a plain-text file.

    The text is lower-cased, split into word and punctuation tokens, and cut
    into sliding windows of ``batch_size`` tokens taken every ``stride``
    tokens. Each sample pairs one window (token ids) with the one-hot
    encoding of the same window shifted forward by a single token.

    Note that ``batch_size`` is used here as the window length, and
    ``__getitem__`` returns one window rather than a batch of windows.
    """

    def __init__(self, filename, batch_size, stride=1):
        """Store the settings, then load the file and select the train split.

        Args:
            filename: Path of the UTF-8 text file to read.
            batch_size: Number of tokens per window.
            stride: Distance in tokens between consecutive window starts.
        """
        # These must be set first: the base constructor calls ``load``.
        self.filename = filename
        self.stride = stride
        super().__init__(batch_size)

    def load(self):
        """Read the file, build the vocabulary, and cut the token windows."""
        with open(self.filename, 'r', encoding='utf-8') as f:
            self.text = f.read().lower()

        self.vocabulary = sorted(set(self.split_text(self.text)))
        self.vocabulary.extend(['<|eos|>', '<|unk|>'])
        self.word2index = {word: index for index, word in enumerate(self.vocabulary)}
        self.index2word = {index: word for index, word in enumerate(self.vocabulary)}
        self.tokens = self.encode(self.text)

        self.batchs = []
        # Target for each window: the same span shifted by exactly one token.
        # Kept separately because ``batchs[i + 1]`` is shifted by ``stride``
        # tokens, which is the next-token target only when ``stride == 1``.
        self._targets = []
        # Windows start at token 1 (as before), and the upper bound leaves
        # room for the one-token-shifted target of the last window.
        for start in range(1, len(self.tokens) - self.batch_size, self.stride):
            stop = start + self.batch_size
            self.batchs.append(self.tokens[start: stop])
            self._targets.append(self.tokens[start + 1: stop + 1])

    @staticmethod
    def split_text(text):
        """Lower-case *text* and split it into word and punctuation tokens."""
        words = re.split(r'([,.:;?_!"()\']|\s)', text.lower())
        return [t.strip() for t in words if t.strip()]

    def _select(self, indices):
        """Make the windows at *indices* the active features and labels."""
        self.features = [self.batchs[i] for i in indices]
        self.labels = [self.onehot(self._targets[i]) for i in indices]

    def train(self):
        """Activate the training split: windows 1 up to 80% of all windows."""
        split = len(self.batchs) * 8 // 10
        self._select(range(1, split))

    def eval(self):
        """Activate the evaluation split: the windows after the 80% mark.

        The window at the split index and the final window are left out,
        matching the index ranges this class has always used.
        """
        split = len(self.batchs) * 8 // 10
        self._select(range(split + 1, len(self.batchs) - 1))

    def __len__(self):
        """Number of samples (windows) in the active split."""
        return len(self.features)

    def __getitem__(self, index):
        """Return sample *index* as ``(token ids, one-hot next tokens)``."""
        return Tensor(self.features[index]), Tensor(self.labels[index])

    def encode(self, text):
        """Convert *text* to token ids; unknown words map to ``<|unk|>``."""
        unknown = self.word2index['<|unk|>']
        return [self.word2index.get(word, unknown) for word in self.split_text(text)]

    def decode(self, tokens):
        """Convert token ids back to text and tidy spacing around punctuation."""
        text = " ".join([self.index2word[index] for index in tokens])
        # No space before closing punctuation or after opening brackets.
        text = re.sub(r'\s+([,.:;?!)\]}>])', r'\1', text)
        text = re.sub(r'([([<{])\s+', r'\1', text)
        # Pull quoted content up against its surrounding quote marks.
        text = re.sub(r'(")\s+(.*?)\s+(")', r'\1\2\3', text)
        text = re.sub(r"(')\s+(.*?)\s+(')", r'\1\2\3', text)
        return text.strip()

    def onehot(self, tokens):
        """Return a ``(len(tokens), vocabulary size)`` one-hot matrix."""
        ebd = np.zeros((len(tokens), len(self.vocabulary)))
        ebd[np.arange(len(tokens)), tokens] = 1
        return ebd

    def estimate(self, predictions):
        """Placeholder required by the base class; currently does nothing."""
        pass

## Configuration

### Context Size

In [50]:
# Number of tokens per context window; passed to LLMDataset as batch_size.
CONTEXT_SIZE = 8

## Testing

### Estimating

In [51]:
# Build the dataset with CONTEXT_SIZE tokens per window.
dataset = LLMDataset('../one-day.txt', CONTEXT_SIZE)

# Print every sample as token ids next to the ids of its one-hot label rows.
for i in range(len(dataset)):
    features, labels = dataset[i]
    label_ids = np.array([int(row.argmax()) for row in labels.data])
    print("(Feature, Label): ", features.data, label_ids)

(Feature, Label):  [113 226 125 145   5 264  70 137] [226 125 145   5 264  70 137 126]
(Feature, Label):  [226 125 145   5 264  70 137 126] [125 145   5 264  70 137 126 226]
(Feature, Label):  [125 145   5 264  70 137 126 226] [145   5 264  70 137 126 226 217]
(Feature, Label):  [145   5 264  70 137 126 226 217] [  5 264  70 137 126 226 217 178]
(Feature, Label):  [  5 264  70 137 126 226 217 178] [264  70 137 126 226 217 178 155]
(Feature, Label):  [264  70 137 126 226 217 178 155] [ 70 137 126 226 217 178 155 226]
(Feature, Label):  [ 70 137 126 226 217 178 155 226] [137 126 226 217 178 155 226 106]
(Feature, Label):  [137 126 226 217 178 155 226 106] [126 226 217 178 155 226 106   3]
(Feature, Label):  [126 226 217 178 155 226 106   3] [226 217 178 155 226 106   3  33]
(Feature, Label):  [226 217 178 155 226 106   3  33] [217 178 155 226 106   3  33 199]
(Feature, Label):  [217 178 155 226 106   3  33 199] [178 155 226 106   3  33 199 113]
(Feature, Label):  [178 155 226 106   3  33