# 特殊字符 (Special Tokens)

In [38]:
import re
from abc import abstractmethod, ABC
import numpy as np

## Foundation

### Tensor

In [39]:
class Tensor:
    """Minimal reverse-mode autograd tensor backed by a NumPy array.

    Every tensor produced by an operation records the tensors it was computed
    from in ``parents`` and a closure in ``gradient_fn`` that adds its own
    ``grad`` (scaled by the local derivative) into those parents' ``grad``.
    """

    def __init__(self, data):
        self.data = np.array(data)
        # Accumulated gradient. Starts as the scalar 0 so the first
        # contribution simply replaces it through ordinary addition.
        self.grad = 0
        # Leaf tensors have nothing to propagate.
        self.gradient_fn = lambda: None
        self.parents = set()

    def backward(self, grad=None):
        """Propagate ``self.grad`` to every tensor this one depends on.

        Args:
            grad: Optional seed gradient. When given it replaces ``self.grad``
                before propagation; otherwise the caller must already have set
                ``self.grad`` (e.g. to 1 for a scalar loss).

        Each node's ``gradient_fn`` runs exactly once, and only after every
        consumer of that node has added its contribution to the node's
        ``grad``. Naively recursing into each parent would re-run a shared
        intermediate node once per consumer and over-count its gradient.
        Gradients accumulate across calls; reset ``grad`` to 0 between passes.
        """
        if grad is not None:
            self.grad = grad

        for node in reversed(self._topological_order()):
            if node.gradient_fn:
                node.gradient_fn()

    def _topological_order(self):
        """Return this tensor and all its ancestors, each listed after its parents.

        Iterative post-order DFS, so deep graphs do not hit Python's
        recursion limit.
        """
        order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All parents of ``node`` have already been appended.
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for parent in node.parents:
                if parent not in visited:
                    stack.append((parent, False))
        return order

    @staticmethod
    def _unbroadcast(grad, shape):
        """Sum ``grad`` over the axes NumPy broadcasting expanded to go past ``shape``.

        A ``grad`` with fewer dimensions than ``shape`` (e.g. a scalar seed)
        is returned unchanged; it still broadcasts correctly when accumulated.
        """
        grad = np.asarray(grad)
        # Sum away leading axes that broadcasting prepended.
        while grad.ndim > len(shape):
            grad = grad.sum(axis=0)
        # Sum over axes that were stretched from size 1.
        if grad.ndim == len(shape):
            for axis, dim in enumerate(shape):
                if dim == 1 and grad.shape[axis] != 1:
                    grad = grad.sum(axis=axis, keepdims=True)
        return grad

    def shape(self):
        """Return the shape of the underlying array."""
        return self.data.shape

    def size(self):
        """Return the number of elements per sample.

        This is the product of every axis except the leading one. For 0-d or
        1-D data it is ``np.prod(()) == 1.0``.
        """
        return np.prod(self.data.shape[1:])

    def __str__(self):
        return str(self.data)

    def __add__(self, other):
        p = Tensor(self.data + other.data)

        def gradient_fn():
            # d(a + b)/da = d(a + b)/db = 1; undo any broadcasting per operand.
            # Out-of-place accumulation avoids in-place dtype-casting errors
            # (e.g. an int gradient array receiving a float contribution).
            self.grad = self.grad + self._unbroadcast(p.grad, self.data.shape)
            other.grad = other.grad + self._unbroadcast(p.grad, other.data.shape)

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

    def __mul__(self, other):
        p = Tensor(self.data * other.data)

        def gradient_fn():
            # d(a * b)/da = b and d(a * b)/db = a (element-wise).
            self.grad = self.grad + self._unbroadcast(p.grad * other.data, self.data.shape)
            other.grad = other.grad + self._unbroadcast(p.grad * self.data, other.data.shape)

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

    def concat(self, other, axis):
        """Concatenate ``self`` and ``other`` along ``axis`` (autograd-aware)."""
        p = Tensor(np.concatenate([self.data, other.data], axis=axis))

        def gradient_fn():
            # A scalar seed gradient must be expanded before it can be split.
            full = np.broadcast_to(p.grad, p.data.shape)
            # The first ``self.data.shape[axis]`` slices along ``axis`` came from self.
            self_part, other_part = np.split(full, [self.data.shape[axis]], axis=axis)
            self.grad = self.grad + self_part
            other.grad = other.grad + other_part

        p.gradient_fn = gradient_fn
        p.parents = {self, other}
        return p

### Base Dataset

In [40]:
class Dataset(ABC):
    """Abstract base for in-memory datasets served in fixed-size batches.

    Subclasses implement ``load``, which the default ``train``/``eval`` expect
    to populate ``train_features``, ``train_labels``, ``test_features`` and
    ``test_labels``, and ``estimate`` to score predictions.
    """

    def __init__(self, batch_size=1):
        if batch_size < 1:
            # Guard here: __len__ would otherwise divide by zero later.
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.batch_size = batch_size
        self.load()
        self.train()

    @abstractmethod
    def load(self):
        pass

    def train(self):
        """Serve the training split."""
        self.features = self.train_features
        self.labels = self.train_labels

    def eval(self):
        """Serve the test split."""
        self.features = self.test_features
        self.labels = self.test_labels

    def shape(self):
        """Return ``(feature_size, label_size)`` per sample, via ``Tensor.size``."""
        return Tensor(self.features).size(), Tensor(self.labels).size()

    def items(self):
        """Return the whole active split as ``(features, labels)`` Tensors."""
        return Tensor(self.features), Tensor(self.labels)

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped.
        return len(self.features) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(features, labels)`` Tensors.

        Negative indices count from the end. Raises IndexError outside
        ``[-len(self), len(self))``; this also makes ``for batch in dataset``
        terminate, because Python's sequence-iteration fallback stops on
        IndexError.
        """
        num_batches = len(self)
        if index < 0:
            index += num_batches
        if not 0 <= index < num_batches:
            raise IndexError(f'batch index out of range (dataset has {num_batches} batches)')
        start = index * self.batch_size
        end = start + self.batch_size
        return Tensor(self.features[start: end]), Tensor(self.labels[start: end])

    @abstractmethod
    def estimate(self, predictions):
        pass

## Data

### LLM Dataset

In [41]:
class LLMDataset(Dataset):
    """Word-level tokenizer built from a UTF-8 text corpus.

    The vocabulary is every distinct lower-cased token produced by
    ``split_text`` (words plus the punctuation characters it splits on),
    sorted, followed by the special tokens ``<|eos|>`` and ``<|unk|>``.
    There are no train/test splits: ``train``, ``eval`` and ``estimate`` are
    no-ops (the base ``__init__`` calls ``train``).
    """

    def __init__(self, filename):
        self.filename = filename
        super().__init__()

    def load(self):
        """Read the corpus and build ``vocabulary``, ``word2index`` and ``index2word``."""
        with open(self.filename, 'r', encoding='utf-8') as f:
            self.text = f.read().lower()

        special_tokens = ['<|eos|>', '<|unk|>']
        # split_text keeps '<', '|' and '>' inside a token, so a corpus that
        # literally contains a special token would otherwise list it twice.
        # Drop those so each special token appears once, at the end.
        words = set(self.split_text(self.text)).difference(special_tokens)
        self.vocabulary = sorted(words) + special_tokens
        self.word2index = {word: index for index, word in enumerate(self.vocabulary)}
        self.index2word = dict(enumerate(self.vocabulary))

    @staticmethod
    def split_text(text):
        """Lower-case ``text`` and split it into word and punctuation tokens.

        Splits on whitespace and on each of , . : ; ? _ ! " ( ) ' ; the
        punctuation is kept as its own token (capturing group) and whitespace
        is discarded. Characters such as ``<``, ``|`` and ``>`` are not split
        on, so ``<|eos|>`` survives as a single token.
        """
        words = re.split(r'([,.:;?_!"()\']|\s)', text.lower())
        return [t.strip() for t in words if t.strip()]

    def train(self):
        pass

    def eval(self):
        pass

    def encode(self, text):
        """Map ``text`` to token ids; out-of-vocabulary tokens map to ``<|unk|>``."""
        unk_index = self.word2index['<|unk|>']
        return [self.word2index.get(word, unk_index) for word in self.split_text(text)]

    def decode(self, tokens):
        """Map token ids back to text, re-attaching punctuation heuristically."""
        text = " ".join([self.index2word[index] for index in tokens])
        # Remove the space before closing punctuation and brackets.
        text = re.sub(r'\s+([,.:;?!)\]}>])', r'\1', text)
        # Remove the space after opening brackets.
        text = re.sub(r'([([<{])\s+', r'\1', text)
        # Pull double quotes onto the text they enclose, pairing them left to right.
        text = re.sub(r'(")\s+(.*?)\s+(")', r'\1\2\3', text)
        # Same for single quotes. NOTE: a contraction such as "don't" is
        # tokenized as "don ' t", and a lone apostrophe is not re-joined here.
        text = re.sub(r"(')\s+(.*?)\s+(')", r'\1\2\3', text)
        return text.strip()

    def estimate(self, predictions):
        pass

## Testing

### Encoding and Decoding

In [42]:
# Round-trip sample text through the tokenizer: join two sentences with the
# end-of-sequence marker, encode them to ids, then decode the ids back.
dataset = LLMDataset('../one-day.txt')

sentences = [
    '"Be careful" his mother says quickly.',
    '"I will," Tom replies.',
]
joined = ' <|eos|> '.join(sentences)
ids = dataset.encode(joined)

print('Text: ', sentences)
print('Encode: ', ids)
print('Decode: ', dataset.decode(ids))

Text:  ['"Be careful" his mother says quickly.', '"I will," Tom replies.']
Encode:  [0, 27, 47, 0, 108, 138, 188, 266, 3, 265, 0, 112, 257, 2, 0, 238, 175, 3]
Decode:  "be careful" his mother says <|unk|>. <|eos|> "i will," tom replies.
