# 第 5 章实现的 RNN 模型

In [None]:
class RNN(nn.Module):
    """Hand-written single-step recurrent classifier.

    Every call to ``forward`` consumes one input step together with the
    previous hidden state, and yields log-probabilities over the output
    classes plus the next hidden state. The caller is responsible for
    looping over the sequence and threading the hidden state through.
    """

    def __init__(self, word_count, embedding_size, hidden_size, output_size):
        """Build the embedding table and the two linear projections.

        Args:
            word_count: Number of rows in the embedding table.
            embedding_size: Width of each embedding vector.
            hidden_size: Width of the recurrent hidden state.
            output_size: Number of scores produced per step.
        """
        super().__init__()

        # Both projections read the concatenation [embedding, hidden].
        joint_size = embedding_size + hidden_size

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(word_count, embedding_size)
        self.i2s = nn.Linear(joint_size, hidden_size)
        self.i2o = nn.Linear(joint_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden):
        """Advance the network by one step.

        Args:
            input_tensor: Indices looked up in the embedding table.
                Presumably a single index of shape (1,), so that the
                embedding lines up with a (1, hidden_size) state -- confirm
                against the caller.
            hidden: Hidden state from the previous step (see ``initHidden``).

        Returns:
            A pair ``(output, hidden)``: log-softmax scores taken over
            dimension 1, and the new hidden state.
        """
        embedded = self.embedding(input_tensor)
        joint = torch.cat([embedded, hidden], dim=1)
        # The new state is a plain linear map of the joint vector; no
        # activation function is applied to it.
        next_hidden = self.i2s(joint)
        log_probs = self.softmax(self.i2o(joint))
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, hidden_size)."""
        return torch.zeros((1, self.hidden_size))

# 使用 PyTorch 中的 RNN

In [None]:
import torch.nn as nn
class RNN(nn.Module):
    """Sequence classifier built on ``torch.nn.RNN``.

    The whole token sequence is embedded, fed through a single-layer,
    unidirectional RNN, and the RNN output at the final time step is
    projected to class scores and normalised with log-softmax.
    """

    def __init__(self, word_count, embedding_size, hidden_size, output_size):
        """Create the embedding, recurrent layer and classification head.

        Args:
            word_count: Number of rows in the embedding table.
            embedding_size: Width of each embedding vector (RNN input size).
            hidden_size: Width of the RNN hidden state.
            output_size: Number of classes scored by the head.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(word_count, embedding_size)
        self.rnn = nn.RNN(embedding_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True)
        self.cls = nn.Linear(hidden_size, output_size)
        # forward() hands a 1-D score vector to the softmax, hence dim=0.
        self.softmax = nn.LogSoftmax(dim=0)

    def forward(self, input_tensor):
        """Classify one sequence.

        Args:
            input_tensor: Token indices of shape (1, seq_len), i.e. a batch
                holding a single sequence. Only the first sequence of the
                batch is used.

        Returns:
            A 1-D tensor of ``output_size`` log-probabilities.
        """
        word_vector = self.embedding(input_tensor)
        # With batch_first=True the RNN output is (batch, seq_len, hidden).
        sequence_output, _ = self.rnn(word_vector)
        # Take the LAST time step of the first sequence. The previous code
        # indexed with len(input_tensor) - 1, but len() of a (1, seq_len)
        # tensor is the batch size, so it always selected time step 0 and
        # the classifier never saw anything beyond the first token.
        last_step = sequence_output[0, -1]
        return self.softmax(self.cls(last_step))


In [None]:
def run_rnn(rnn, input_tensor):
    """Call ``rnn`` on ``input_tensor`` with a leading axis of size 1 added.

    Args:
        rnn: Callable model that receives the expanded tensor.
        input_tensor: Tensor for a single example; a new dimension 0 is
            inserted in front of its existing dimensions.

    Returns:
        Whatever ``rnn`` returns for the expanded input.
    """
    batched = input_tensor.unsqueeze(0)
    return rnn(batched)

In [None]:
def train(rnn, criterion, input_tensor, category_tensor):
    """Run one plain-SGD training step on a single example.

    The step size is read from the module-level ``learning_rate``.

    Args:
        rnn: Model to train; it is invoked through ``run_rnn``.
        criterion: Loss callable, invoked as
            ``criterion(output.unsqueeze(dim=0), category_tensor)``.
        input_tensor: Input for one example, forwarded to ``run_rnn``.
        category_tensor: Target handed to ``criterion``.

    Returns:
        A pair ``(output, loss_value)``: the model output computed before
        the parameter update, and the loss as a Python number.
    """
    rnn.zero_grad()
    output = run_rnn(rnn, input_tensor)
    loss = criterion(output.unsqueeze(dim=0), category_tensor)
    loss.backward()
    # Update the model parameters from their gradients. The update runs
    # under no_grad() so autograd does not record it (instead of reaching
    # through the discouraged ``.data`` attribute).
    with torch.no_grad():
        for p in rnn.parameters():
            # Frozen or unused parameters have no gradient; leave them as is.
            if p.grad is None:
                continue
            p.add_(p.grad, alpha=-learning_rate)
    return output, loss.item()
def evaluate(rnn, input_tensor):
    """Run the model on one example with gradient tracking disabled.

    Args:
        rnn: Model to evaluate; it is invoked through ``run_rnn``.
        input_tensor: Input for one example, forwarded to ``run_rnn``.

    Returns:
        The value returned by ``run_rnn`` for this input.
    """
    with torch.no_grad():
        return run_rnn(rnn, input_tensor)
