In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ElmanRNN(nn.Module):
    """Elman (vanilla) RNN implemented by unrolling an ``nn.RNNCell`` over time.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Number of features in the hidden state.
        batch_first: If True, ``forward`` expects input laid out as
            (batch, seq, feature) and returns (batch, seq, hidden_size);
            otherwise input is (seq, batch, feature) and the output is
            (seq, batch, hidden_size).
    """

    def __init__(self, input_size, hidden_size, batch_first=False):
        super().__init__()
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)

        self.batch_first = batch_first
        self.hidden_size = hidden_size

    def _initialize_hidden(self, batch_size, device=None, dtype=None):
        """Return an all-zero hidden state of shape (batch_size, hidden_size).

        ``device`` and ``dtype`` are optional; when omitted, PyTorch's defaults
        are used, which matches the original call ``_initialize_hidden(batch_size)``.
        """
        return torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)

    def forward(self, x_in, initial_hidden=None):
        """Apply the RNN cell to every time step of ``x_in``.

        Args:
            x_in: 3-D input tensor whose layout is selected by ``batch_first``.
            initial_hidden: Optional starting hidden state of shape
                (batch, hidden_size). Defaults to zeros on the same device and
                with the same dtype as ``x_in``.

        Returns:
            The hidden state at every time step, stacked along the sequence
            dimension and returned in the same batch/seq layout as the input.
            An empty sequence yields an empty tensor instead of an error.
        """
        if self.batch_first:
            batch_size, seq_size, _ = x_in.size()
            # The loop below indexes time along dim 0, so move seq to the front.
            x_in = x_in.permute(1, 0, 2)
        else:
            seq_size, batch_size, _ = x_in.size()

        if initial_hidden is None:
            # Match both device AND dtype of the input: a default-dtype (float32)
            # zero state would make RNNCell fail for float64/float16 models.
            initial_hidden = self._initialize_hidden(
                batch_size, device=x_in.device, dtype=x_in.dtype)

        hidden_t = initial_hidden
        hiddens = []
        for t in range(seq_size):
            hidden_t = self.rnn_cell(x_in[t], hidden_t)
            hiddens.append(hidden_t)

        if hiddens:
            hiddens = torch.stack(hiddens)
        else:
            # torch.stack rejects an empty list; return an empty sequence with
            # the hidden state's dtype/device instead.
            hiddens = hidden_t.new_zeros((0, batch_size, self.hidden_size))

        # The input's first dimension was batch_size, so convert the output
        # tensor back to the same (batch, seq, hidden) layout.
        if self.batch_first:
            hiddens = hiddens.permute(1, 0, 2)

        return hiddens

In [3]:
# Feature count of each input time step, and size of the RNN hidden state.
input_size, hidden_size = 512, 100

# Build the Elman RNN (sequence-first layout, since batch_first defaults to
# False) and show the module structure.
rnn = ElmanRNN(input_size=input_size, hidden_size=hidden_size)
print(rnn)

ElmanRNN(
  (rnn_cell): RNNCell(512, 100)
)
