In [66]:
import torch

In [67]:
class RNN:
    def __init__(self, vocab_size, num_hiddens):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = self._init_params()
        
    def _init_params(self):
        num_inputs = num_outputs = self.vocab_size; num_hiddens = self.num_hiddens
        def normal(shape):
            return torch.randn(size=shape, device=device) * 0.01
        W_xh = normal((num_inputs, num_hiddens)); 
        W_hh = normal((num_hiddens, num_hiddens))
        b_h = torch.zeros(num_hiddens, device=device)
        W_hq = normal((num_hiddens, num_outputs)); 
        b_q = torch.zeros(num_outputs, device=device)
        params = [W_xh, W_hh, b_h, W_hq, b_q]

        for param in params:
            param.requires_grad_(True)
        return params
    
    def init_state(self, batch_size):
        """初始化隐藏状态"""
        return (torch.zeros((batch_size, self.num_hiddens), device=device), )
    
    def _forward(self, inputs, state):    
        W_xh, W_hh, b_h, W_hq, b_q = self.params
        H, = state
        outputs = []
        for X in inputs:
            H = torch.tanh(X @ W_xh + H @ W_hh + b_h)
            Y = H @ W_hq + b_q
            outputs.append(Y)
        return torch.cat(outputs, dim=0), (H, )
    
    def __call__(self, X, state):
        inputs = torch.nn.functional.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self._forward(inputs, state)
    
    def grad_clipping(self, theta):
        norm = torch.sqrt(sum([torch.sum(p) for p in self.params]))
        if norm > theta:
            for p in self.params:
                p.grad[:] *= theta / norm

In [68]:
def seq_data_iter_random(corpus, batch_size, num_steps):  #@save
    """使用随机抽样生成一个小批量子序列"""
    # 从随机偏移量开始对序列进行分区，随机范围包括num_steps-1
    corpus = corpus[random.randint(0, num_steps - 1):]
    # 减去1，是因为我们需要考虑标签
    num_subseqs = (len(corpus) - 1) // num_steps
    # 长度为num_steps的子序列的起始索引
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    # 在随机抽样的迭代过程中，
    # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻
    random.shuffle(initial_indices)

    def data(pos):
        # 返回从pos位置开始的长度为num_steps的序列
        return corpus[pos: pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        # 在这里，initial_indices包含子序列的随机起始索引
        initial_indices_per_batch = initial_indices[i: i + batch_size]
        X = [data(j) for j in initial_indices_per_batch]
        Y = [data(j + 1) for j in initial_indices_per_batch]
        yield torch.tensor(X), torch.tensor(Y)
        
def seq_data_iter_sequential(corpus, batch_size, num_steps):  #@save
    """使用顺序分区生成一个小批量子序列"""
    # 从随机偏移量开始划分序列
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size
    Xs = torch.tensor(corpus[offset: offset + num_tokens])
    Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])
    Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)
    num_batches = Xs.shape[1] // num_steps
    for i in range(0, num_steps * num_batches, num_steps):
        X = Xs[:, i: i + num_steps]
        Y = Ys[:, i: i + num_steps]
        yield X, Y

In [69]:
def tokenize(lines, token='word'): 
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print('错误：未知词元类型：' + token)

def load_corpus_vocab(max_tokens=-1): 
    """词元索引列表和词表"""
    lines = texts
    tokens = d2l.tokenize(lines)
    vocab = d2l.Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens:
        corpus = corpus[:max_tokens]
    return corpus, vocab

class SeqDataLoader: 
    """加载序列数据的迭代器"""
    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        if use_random_iter:
            self.data_iter_fn = d2l.seq_data_iter_random
        else:
            self.data_iter_fn = d2l.seq_data_iter_sequential
        self.corpus, self.vocab = load_corpus_vocab(max_tokens)
        self.batch_size, self.num_steps = batch_size, num_steps

    def __iter__(self):
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)
    
def load_data(batch_size, num_steps,  #@save
                           use_random_iter=False, max_tokens=10000):
    """返回迭代器和词表"""
    data_iter = SeqDataLoader(
        batch_size, num_steps, use_random_iter, max_tokens)
    return data_iter, data_iter.vocab

In [70]:
batch_size, num_steps, device = 32, 5, 'cpu'
texts = ['I am a cat'] * 1000
train_iter, vocab = load_data(batch_size, num_steps, max_tokens=10000)

In [71]:
num_epochs, lr = 10, 1
net = RNN(len(vocab), 256)
updater = torch.optim.SGD(net.params, lr)
loss = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    state = None
    metric = [0, 0]
    for X, Y in train_iter:
        if state is None:
            state = net.init_state(batch_size=X.shape[0])
        else:
            for s in state: s.detach_()
        y_hat, state = net(X, state)
        y = Y.T.reshape(-1)
        l = loss(y_hat, y.long()).mean()
        updater.zero_grad()
        l.backward()
        net.grad_clipping(1)
        updater.step()
        metric[0] += l * y.numel(); metric[1] += y.numel()
    print('epoch %d 困惑度 %f' % (epoch + 1, torch.exp(metric[0] / metric[1])))

epoch 1 困惑度 1.748636
epoch 2 困惑度 1.007764
epoch 3 困惑度 1.005285
epoch 4 困惑度 1.002974
epoch 5 困惑度 1.003444
epoch 6 困惑度 1.002012
epoch 7 困惑度 1.001384
epoch 8 困惑度 1.001039
epoch 9 困惑度 1.000825
epoch 10 困惑度 1.006340


In [72]:
def predict(prefix, num_preds, net, vocab, device):
    state = net.init_state(batch_size=1)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape(1, 1)
    for y in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1, 1)))
    return ' '.join([vocab.idx_to_token[i] for i in outputs])

In [73]:
predict('I am a'.split(' '), 1, net, vocab, device)

'I am a cat'