In [3]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNModel(nn.Module):
    """Recurrent language model: embedding encoder -> RNN stack -> linear decoder.

    Args:
        rnn_type: One of 'LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU'.
        ntoken: Vocabulary size (rows of the embedding, width of the decoder output).
        ninp: Embedding dimension fed into the recurrent layers.
        nhid: Hidden size of each recurrent layer.
        nlayers: Number of stacked recurrent layers.
        dropout: Dropout probability applied to the embeddings and passed to the
            recurrent module.
        tie_weights: If True, the decoder shares its weight matrix with the
            encoder embedding; this requires ``nhid == ninp``.

    Raises:
        ValueError: If ``rnn_type`` is not a supported option, or if
            ``tie_weights`` is set while ``nhid != ninp``.
    """

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super().__init__()
        self.rnn_type = rnn_type
        self.ntoken = ntoken
        self.ninp = ninp
        self.nhid = nhid
        self.nlayers = nlayers
        self.tie_weights = tie_weights

        self.encoder = nn.Embedding(ntoken, ninp)

        self.dropout = nn.Dropout(dropout)
        if rnn_type in ["LSTM", "GRU"]:
            # nn.LSTM / nn.GRU are looked up by name on the torch.nn module.
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH":"tanh", "RNN_RELU":"relu"}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)

        self.decoder = nn.Linear(nhid, ntoken)

        if tie_weights:
            # decoder.weight is (ntoken, nhid) and encoder.weight is (ntoken, ninp),
            # so the two can only be the same Parameter when nhid == ninp.
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

    def init_weights(self, initrange):
        """Initialise encoder/decoder weights uniformly in [-initrange, initrange].

        The decoder bias is zeroed. The encoder has no bias to initialise:
        nn.Embedding only owns a ``weight`` parameter.
        """
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        """Run one forward pass.

        Args:
            input: Tensor of token indices for the embedding lookup. The recurrent
                module is built with its default layout (batch_first not set), so
                this is presumably (seq_len, batch) -- confirm against the caller.
            hidden: Initial recurrent state, e.g. the value from ``init_hidden``.

        Returns:
            A pair ``(log_probs, hidden)`` where ``log_probs`` has shape
            (N, ntoken) with all leading dimensions flattened into N, and
            ``hidden`` is the state returned by the recurrent module.
        """
        emb = self.encoder(input)
        input = self.dropout(emb)
        output, hidden = self.rnn(input, hidden)
        decoded = self.decoder(output)
        # Flatten every leading dimension so each row is one distribution over the
        # vocabulary. The attribute is `ntoken`; `self.token` was never defined.
        decoded = decoded.reshape(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, batch_size):
        """Return an all-zero initial state of shape (nlayers, batch_size, nhid).

        For 'LSTM' a tuple of two such tensors is returned; otherwise a single
        tensor. The tensors take dtype and device from the model's first parameter.
        """
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            return (
                weight.new_zeros(self.nlayers, batch_size, self.nhid),
                weight.new_zeros(self.nlayers, batch_size, self.nhid)
            )
        else:
            return weight.new_zeros(self.nlayers, batch_size, self.nhid)

In [5]:
# Two-layer LSTM language model with tied encoder/decoder weights
# (tying needs ninp == nhid, both 512 here).
rnn = RNNModel(
    rnn_type='LSTM',
    ntoken=10086,
    ninp=512,
    nhid=512,
    nlayers=2,
    dropout=0.5,
    tie_weights=True,
)

In [21]:
# Demonstrates Tensor.new_zeros: builds a zero tensor of the requested size that
# takes dtype and device from the source tensor (here the embedding weight).
rnn.encoder.weight.new_zeros(2)

tensor([0., 0.])