$$ I_{t} = \sigma(X_{t}W_{xi} + H_{t-1}W_{hi} + b_{i}) $$
$$ F_{t} = \sigma(X_{t}W_{xf} + H_{t-1}W_{hf} + b_{f}) $$
$$ O_{t} = \sigma(X_{t}W_{xo} + H_{t-1}W_{ho} + b_{o}) $$
$$ \hat{C}_{t} = tanh(X_{t}W_{xc} + H_{t-1}W_{hc} + b_{c}) $$
$$ C_{t} = F_{t} \odot C_{t-1} + I_{t} \odot \hat{C}_{t} $$
$$ H_{t} = O_{t} \odot tanh(C_{t}) $$
$$ Y_{t} = H_{t}W_{hq} + b_{q} $$

In [15]:
import torch
from torch.nn import functional as F
from tqdm import tqdm
from d2l import torch as d2l

In [17]:
class LSTM:
    def __init__(self, vocab_size, num_hiddens):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = self.init_params()
        
    def init_params(self):
        num_inputs = num_outputs = self.vocab_size; num_hiddens = self.num_hiddens
        def normal(shape):
            return torch.randn(size=shape, device=device) * 0.01
        def three():
            return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, device=device))
        W_xi, W_hi, b_i = three(); W_xf, W_hf, b_f = three(); W_xo, W_ho, b_o = three(); W_xc, W_hc, b_c = three();
        W_hq, b_q = normal((num_hiddens, num_outputs)), torch.zeros(num_outputs, device=device)
        params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
        for param in params:
            param.requires_grad_(True)
        return params
    
    def init_state(self, batch_size):
        return (torch.zeros((batch_size, self.num_hiddens), device=device), torch.zeros((batch_size, self.num_hiddens), device=device))
    
    def forward(self, inputs, state):
        C, H = state; outputs = []
        W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = self.params
        for X in inputs:
            I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i)
            F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f)
            O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o)
            C_hat = torch.tanh(X @ W_xc + H @ W_hc + b_c)
            C = F * C + I * C_hat
            H = O * torch.tanh(C)
            Y = H @ W_hq + b_q
            outputs.append(Y)
        return torch.cat(outputs, dim=0), (C, H)
    
    def __call__(self, X, state):
        inputs = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward(inputs, state)
    
    def grad_clipping(self, theta):
        norm = torch.sqrt(sum([torch.sum(p) for p in self.params]))
        if norm > theta:
            for p in self.params:
                p.grad[:] *= theta / norm

In [25]:
def seq_data_iter_sequential(corpus, batch_size, num_steps):  #@save
    """使用顺序分区生成一个小批量子序列"""
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size
    Xs = torch.tensor(corpus[offset: offset + num_tokens])
    Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])
    Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)
    num_batches = Xs.shape[1] // num_steps
    for i in range(0, num_steps * num_batches, num_steps):
        X = Xs[:, i: i + num_steps]
        Y = Ys[:, i: i + num_steps]
        yield X, Y

def tokenize(lines, token='word'): 
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print('错误：未知词元类型：' + token)

def load_corpus_vocab(max_tokens=-1): 
    """词元索引列表和词表"""
    lines = texts
    tokens = d2l.tokenize(lines)
    vocab = d2l.Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens:
        corpus = corpus[:max_tokens]
    return corpus, vocab

class SeqDataLoader: 
    """加载序列数据的迭代器"""
    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        if use_random_iter:
            self.data_iter_fn = d2l.seq_data_iter_random
        else:
            self.data_iter_fn = d2l.seq_data_iter_sequential
        self.corpus, self.vocab = load_corpus_vocab(max_tokens)
        self.batch_size, self.num_steps = batch_size, num_steps

    def __iter__(self):
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)
    
def load_data(batch_size, num_steps,  #@save
                           use_random_iter=False, max_tokens=10000):
    """返回迭代器和词表"""
    data_iter = SeqDataLoader(
        batch_size, num_steps, use_random_iter, max_tokens)
    return data_iter, data_iter.vocab

In [31]:
batch_size, num_steps, device = 32, 35, 'cpu'
# train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
texts = ['I love cat'] * 1000
train_iter, vocab = load_data(batch_size, num_steps, max_tokens=10000)
num_epochs, num_hiddens = 100, 256
net = LSTM(len(vocab), num_hiddens)
loss = torch.nn.CrossEntropyLoss()
updater = torch.optim.SGD(net.params, lr=1)

In [41]:
for epoch in range(num_epochs):
    state = None; metrics = [0, 0]
    for X, Y in train_iter:
        if state is None:
            state = net.init_state(batch_size)
        for s in state: s.detach_()
        y_hat, state = net(X, state)
        y = Y.T.reshape(-1)
        l = loss(y_hat, y)
        l.backward()
        net.grad_clipping(1)
        updater.step()
        metrics[0] += l * y.numel(); metrics[1] += y.numel()
    if (epoch + 1) % (num_epochs // 10) == 0:
        print('epoch %d 困惑度 %f' % (epoch + 1, torch.exp(metrics[0] / metrics[1])))

epoch 10 困惑度 1.008207
epoch 20 困惑度 1.007645
epoch 30 困惑度 1.007135
epoch 40 困惑度 1.006697
epoch 50 困惑度 1.006306
epoch 60 困惑度 1.005967
epoch 70 困惑度 1.005650
epoch 80 困惑度 1.005362
epoch 90 困惑度 1.005101
epoch 100 困惑度 1.004853


In [42]:
def predict(prefix, num_preds):
    state = net.init_state(1)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor(outputs[-1]).reshape(1, 1)
    for y in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1)))
    return ' '.join([vocab.idx_to_token[i] for i in outputs])

In [43]:
predict('I love'.split(' '), 1)

'I love cat'