# In [None]:
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """Stacked LSTM cell that advances every layer by a single time step.

    Each layer owns two linear maps, input-to-hidden (``w_ih``) and
    hidden-to-hidden (``w_hh``); each one produces the pre-activations of
    all four gates at once (hence the ``4 * hidden_size`` output width).

    Args:
        input_size: Number of features of the input fed to the first layer.
        hidden_size: Number of features of every layer's hidden/cell state.
        num_layers: Number of stacked layers.
        dropout: Dropout probability applied to a layer's hidden output
            before it is handed to the next layer as its input.
    """

    def __init__(self, input_size, hidden_size, num_layers = 1, dropout = 0.1):
        super().__init__()

        self.num_layers = num_layers
        self.dropout = nn.Dropout(p=dropout)

        ih, hh = [], []
        for layer in range(num_layers):
            # Only the first layer sees the raw input; deeper layers consume
            # the hidden output of the layer below.
            in_features = input_size if layer == 0 else hidden_size
            ih.append(nn.Linear(in_features, 4 * hidden_size))
            hh.append(nn.Linear(hidden_size, 4 * hidden_size))
        self.w_ih = nn.ModuleList(ih)
        self.w_hh = nn.ModuleList(hh)

    def forward(self, input, hidden):
        """Run one time step through all layers.

        Args:
            input: Input for this time step, shaped ``(batch, input_size)``.
            hidden: Pair ``(h, c)`` of previous hidden and cell states.
                Each is either already stacked per layer as
                ``(num_layers, batch, hidden_size)``, or a state with fewer
                than three dimensions (e.g. ``(batch, hidden_size)``), in
                which case the same state is replicated for every layer.

        Returns:
            Tuple ``(hy, cy)`` with the new hidden and cell states of all
            layers, each stacked along a leading layer dimension:
            ``(num_layers, batch, hidden_size)``.
        """
        h_prev, c_prev = hidden[0], hidden[1]
        # Decide on rank first: looking at shape[0] alone would mistake a
        # 2-D (batch, hidden) state for a per-layer state whenever the batch
        # size happens to equal num_layers, and each layer would then pick
        # up a single sample's vector and broadcast it over the batch.
        if h_prev.dim() < 3 or h_prev.shape[0] != self.num_layers:
            h_prev = torch.tile(h_prev, [self.num_layers, 1, 1])
            c_prev = torch.tile(c_prev, [self.num_layers, 1, 1])

        hy, cy = [], []
        for layer in range(self.num_layers):
            hx, cx = h_prev[layer], c_prev[layer]
            gates = self.w_ih[layer](input) + self.w_hh[layer](hx)
            # Split the fused projection into input, forget, candidate and
            # output gate pre-activations along the feature dimension.
            i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
            i_gate = torch.sigmoid(i_gate)
            f_gate = torch.sigmoid(f_gate)
            c_gate = torch.tanh(c_gate)
            o_gate = torch.sigmoid(o_gate)
            ncx = (f_gate * cx) + (i_gate * c_gate)
            nhx = o_gate * torch.tanh(ncx)
            cy.append(ncx)
            hy.append(nhx)
            # The returned states stay un-dropped; dropout only affects what
            # the next layer receives as its input.
            input = self.dropout(nhx)

        hy, cy = torch.stack(hy, 0), torch.stack(cy, 0)  # number of layer * batch * hidden
        return hy, cy

# Demo: unroll a 2-layer cell over a random sequence, one time step at a time.
lstm = LSTMCell(10, 20, 2)
input = torch.randn(5, 3, 10)  # [Sequence Length, Batch Size, Input Size]
hx = torch.randn(3, 20)  # [Batch Size, Hidden Size]
cx = torch.randn(3, 20)  # [Batch Size, Cell Size]

output = []
for i in range(len(input)):
    # After the first step hx/cx carry a leading per-layer dimension, as
    # returned by the cell, and are fed back in that stacked form.
    hx, cx = lstm(input[i], (hx, cx))
    output.append(hx)
# Collect the per-step hidden states along a new leading time dimension.
output = torch.stack(output)

#著作权归作者所有。
#商业转载请联系作者获得授权,非商业转载请注明出处。
#原文: https://0809zheng.github.io/2020/03/07/RNN.html