In [89]:
import torch
from torch.nn import functional as F
from d2l import torch as d2l

$$ R_{t} = \sigma(X_{t}W_{xr} + H_{t-1}W_{hr} + b_{r}) $$
$$ Z_{t} = \sigma(X_{t}W_{xz} + H_{t-1}W_{hz} + b_{z}) $$
$$ \hat{H}_{t} = tanh(X_{t}W_{xh}+(R_{t} \odot H_{t-1})W_{hh} + b_{h}) $$
$$ H_{t} = Z_{t} \odot H_{t-1} + (1 - Z_{t}) \odot \hat{H}_{t} $$
$$ Y_{t} = H_{t}W_{hq} +b_{q} $$

In [126]:
class GRU:
    def __init__(self, vocab_size, num_hiddens, get_params, init_state, forward_fn):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = self.init_params()
    
    def init_params(self):
        num_inputs = num_outputs = self.vocab_size
        num_hiddens = self.num_hiddens
        def normal(shape):
            return torch.randn(size=shape, device=device) * 0.01
        def three():
            return (normal((num_inputs, num_hiddens)), 
                    normal((num_hiddens, num_hiddens)), 
                    torch.zeros(num_hiddens, device=device))
        W_xr, W_hr, b_r = three()
        W_xz, W_hz, b_z = three()
        W_xh, W_hh, b_h = three()
        W_hq = normal((num_hiddens, num_outputs)); b_q = torch.zeros(num_outputs, device=device)
        params = [W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h, W_hq, b_q]
        for param in params:
            param.requires_grad_(True)
        return params
    
    def init_state(self, batch_size):
        return (torch.zeros((batch_size, self.num_hiddens), device=device),)
    
    def forward(self, inputs, state):
        W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h, W_hq, b_q = self.params
        outputs = []; H, = state
        for X in inputs:
            R = torch.sigmoid(X @ W_xr + H @ W_hr + b_r)
            Z = torch.sigmoid(X @ W_xz + H @ W_hz + b_z)
            H_hat = torch.tanh(X @ W_xh + (R * H) @ W_hh + b_h)
            H = Z * H + (1 - Z) * H_hat
            Y = H @ W_hq + b_q
            outputs.append(Y)
        return torch.cat(outputs, dim=0), (H, )
    
    def __call__(self, X, state):
        inputs = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward(inputs, state)
    
    def grad_clipping(self, theta):
        norm = torch.sqrt(sum([torch.sum(param.grad ** 2) for param in self.params]))
        if norm > theta:
            for param in self.params:
                param.grad[:] *= theta / norm      

def train(net, train_iter, lr, num_epochs):
    loss = nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.params, lr)
    for epoch in tqdm(range(num_epochs), ncols=100):
        state = None
        metrics = [0, 0]
        for X, Y in train_iter:
            if state is None:
                state = net.init_state(X.shape[0])
            for s in state: s.detach_()
            y_hat, state = net(X, state)
            y = Y.T.reshape(-1)
            l = loss(y_hat, y.long()).mean()
            l.backward()
            net.grad_clipping(1)
            updater.step()
            metrics[0] += l * y.numel(); metrics[1] += y.numel()
        l = torch.exp(metrics[0] / metrics[1])
    print('困惑度 %f' % l)

In [127]:
def seq_data_iter_sequential(corpus, batch_size, num_steps):  #@save
    """使用顺序分区生成一个小批量子序列"""
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size
    Xs = torch.tensor(corpus[offset: offset + num_tokens])
    Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])
    Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)
    num_batches = Xs.shape[1] // num_steps
    for i in range(0, num_steps * num_batches, num_steps):
        X = Xs[:, i: i + num_steps]
        Y = Ys[:, i: i + num_steps]
        yield X, Y

def tokenize(lines, token='word'): 
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print('错误：未知词元类型：' + token)

def load_corpus_vocab(max_tokens=-1): 
    """词元索引列表和词表"""
    lines = texts
    tokens = d2l.tokenize(lines)
    vocab = d2l.Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens:
        corpus = corpus[:max_tokens]
    return corpus, vocab

class SeqDataLoader: 
    """加载序列数据的迭代器"""
    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        if use_random_iter:
            self.data_iter_fn = d2l.seq_data_iter_random
        else:
            self.data_iter_fn = d2l.seq_data_iter_sequential
        self.corpus, self.vocab = load_corpus_vocab(max_tokens)
        self.batch_size, self.num_steps = batch_size, num_steps

    def __iter__(self):
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)
    
def load_data(batch_size, num_steps,  #@save
                           use_random_iter=False, max_tokens=10000):
    """返回迭代器和词表"""
    data_iter = SeqDataLoader(
        batch_size, num_steps, use_random_iter, max_tokens)
    return data_iter, data_iter.vocab

In [128]:
batch_size, num_steps = 32, 35
# train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
texts = ['I love cat'] * 1000
train_iter, vocab = load_data(batch_size, num_steps, max_tokens=10000)
vocab_size, num_hiddens, device = len(vocab), 256, 'cpu'
num_epochs, lr = 10, 1
net = GRU(vocab_size, num_hiddens, get_params, init_gru_state, gru)
train(net, train_iter, lr, num_epochs)

100%|███████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.27it/s]

困惑度 1.000001





In [129]:
def predict(prefix, num_preds, net, vocab):
    state = net.init_state(1); outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape(1, 1)
    for y in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1, 1)))
    return ' '.join([vocab.idx_to_token[i] for i in outputs])

In [130]:
predict('I love'.split(' '), 1, net, vocab)

'I love cat'