In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


# RNNCell实现

In [2]:
# Model dimensions: one-hot width, hidden-state width, sequences per batch.
input_size, hidden_size, batch_size = 4, 4, 1

# Vocabulary: a character's position in this list is its class index.
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]    # input sequence:  "hello"
y_data = [3, 1, 2, 3, 2]    # target sequence: "ohlol"

# One-hot lookup table (an identity matrix): row i is the encoding of class
# index i, which makes encoding a sequence a plain row lookup.
one_hot_lookup = [[int(row == col) for col in range(len(idx2char))]
                  for row in range(len(idx2char))]
# Encode every input index -> seq_len rows of input_size entries each.
x_one_hot = list(map(one_hot_lookup.__getitem__, x_data))

In [4]:
# Float inputs reshaped to (seq_len, batch_size, input_size) so that iterating
# over dim 0 yields one (batch_size, input_size) time step at a time.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor
# constructor, whose dtype depends on the global default tensor type.
inputs = torch.tensor(x_one_hot, dtype=torch.float32).view(-1, batch_size, input_size)
# Integer class indices reshaped to (seq_len, batch_size): one target per
# sequence in the batch at every time step.
labels = torch.tensor(y_data, dtype=torch.long).view(-1, batch_size)

In [11]:
# Display both tensor shapes as a sanity check on the reshaping.
inputs.shape, labels.shape

(torch.Size([5, 1, 4]), torch.Size([5, 1]))

In [12]:
class Model(torch.nn.Module):
    """Single-step recurrent model wrapping one ``torch.nn.RNNCell``.

    The caller drives the time loop: every ``forward`` call consumes one time
    step and returns the next hidden state.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden state.
        batch_size: Number of sequences per batch; used only to size the
            initial hidden state returned by ``init_hidden``.
    """

    def __init__(self, input_size, hidden_size, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size, hidden_size=self.hidden_size)

    def forward(self, input, hidden):
        """Advance the cell by one time step.

        Args:
            input: Input for the current step, passed straight to the RNNCell.
            hidden: Hidden state produced by the previous step (or by
                ``init_hidden`` for the first step).

        Returns:
            The new hidden state computed by the RNNCell.
        """
        return self.rnncell(input, hidden)

    def init_hidden(self):
        """Return an all-zero hidden state of shape (batch_size, hidden_size).

        The tensor is allocated from the cell's own weight so it always shares
        the model's dtype and device; a plain ``torch.zeros`` would mismatch
        after ``.double()``, ``.half()`` or ``.to(device)``.
        """
        # new_zeros copies dtype/device from the weight but not requires_grad.
        return self.rnncell.weight_hh.new_zeros(self.batch_size, self.hidden_size)

In [13]:
# Seed the global RNG before the model is built so that the random weight
# initialisation (and therefore the whole training run) is repeatable.
torch.manual_seed(0)

net = Model(input_size, hidden_size, batch_size)

# CrossEntropyLoss takes raw scores and integer class indices.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

In [14]:
# Number of passes over the sequence; also used in the progress message.
num_epochs = 15

for epoch in range(num_epochs):
    optimizer.zero_grad()
    hidden = net.init_hidden()
    loss = 0
    predicted = []
    # Iterating over dim 0 yields one time step per iteration: a
    # (batch_size, input_size) input and the matching target index.
    for step_input, label in zip(inputs, labels):
        # Call the module itself rather than .forward() so hooks still run.
        hidden = net(step_input, hidden)
        # The hidden state doubles as the class scores; sum the per-character
        # losses so one backward pass covers the whole sequence.
        loss += criterion(hidden, label)
        predicted.append(idx2char[hidden.argmax(dim=1).item()])
    loss.backward()
    optimizer.step()
    print('Predicted string: %s, Epoch [%d/%d] loss=%.4f'
          % (''.join(predicted), epoch + 1, num_epochs, loss.item()))

Predicted string: hlhhh, Epoch [1/15] loss=6.9304
Predicted string: hollh, Epoch [2/15] loss=5.8971
Predicted string: hhlll, Epoch [3/15] loss=5.5015
Predicted string: ohlll, Epoch [4/15] loss=5.3366
Predicted string: ohlll, Epoch [5/15] loss=5.1905
Predicted string: ohlll, Epoch [6/15] loss=4.9861
Predicted string: ohlll, Epoch [7/15] loss=4.6878
Predicted string: ohhll, Epoch [8/15] loss=4.3736
Predicted string: ohhll, Epoch [9/15] loss=4.1609
Predicted string: ohhll, Epoch [10/15] loss=3.7910
Predicted string: ohlll, Epoch [11/15] loss=3.3479
Predicted string: ohlll, Epoch [12/15] loss=2.9904
Predicted string: ohlll, Epoch [13/15] loss=2.7354
Predicted string: ohlll, Epoch [14/15] loss=2.5705
Predicted string: ohlll, Epoch [15/15] loss=2.4694


In [1]:
import torch
# Draw a 3x4 tensor of random values; no seed is set here, so the displayed
# numbers differ from run to run.
torch.rand(3, 4)

  from .autonotebook import tqdm as notebook_tqdm


tensor([[0.1158, 0.4469, 0.6086, 0.8961],
        [0.2050, 0.8445, 0.6693, 0.7233],
        [0.1295, 0.4225, 0.5397, 0.5874]])