## 엘만 RNN 구현하기

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ElmanRNN(nn.Module):
    """Elman RNN built by unrolling a single ``nn.RNNCell`` over time."""

    def __init__(self, input_size, hidden_size, batch_first=False):
        """
        Args:
            input_size (int): size of each input vector
            hidden_size (int): size of the hidden state vector
            batch_first (bool): whether dimension 0 of the input is the batch
        """
        super().__init__()
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)

        self.batch_first = batch_first
        self.hidden_size = hidden_size

    def _initialize_hidden(self, batch_size, device=None, dtype=None):
        """Create an all-zero hidden state.

        Args:
            batch_size (int): number of sequences in the batch
            device (torch.device, optional): device to allocate on;
                ``None`` keeps the torch default
            dtype (torch.dtype, optional): dtype of the state;
                ``None`` keeps the torch default
        Returns:
            torch.Tensor: zeros of shape (batch_size, hidden_size)
        """
        return torch.zeros(
            batch_size, self.hidden_size, device=device, dtype=dtype
        )

    def forward(self, x_in, initial_hidden=None):
        """Forward pass of the ElmanRNN.

        Args:
            x_in (torch.Tensor): input data tensor
                If self.batch_first: x_in.shape = (batch_size, seq_size, feat_size)
                Else: x_in.shape = (seq_size, batch_size, feat_size)
            initial_hidden (torch.Tensor): initial hidden state of the RNN;
                when None, a zero state matching x_in's device and dtype is used
        Returns:
            hiddens (torch.Tensor): RNN output at every time step
                If self.batch_first:
                    hiddens.shape = (batch_size, seq_size, hidden_size)
                Else: hiddens.shape = (seq_size, batch_size, hidden_size)
        """
        if self.batch_first:
            batch_size, seq_size, _ = x_in.size()
            # Work time-major internally so x_in[t] is one time step.
            x_in = x_in.permute(1, 0, 2)
        else:
            seq_size, batch_size, _ = x_in.size()

        if initial_hidden is None:
            # Match both device AND dtype of the input; a float32 state would
            # make the cell fail for a model converted with .double()/.half().
            initial_hidden = self._initialize_hidden(
                batch_size, device=x_in.device, dtype=x_in.dtype
            )

        hidden_t = initial_hidden
        hiddens = []

        # Feed one time step at a time, carrying the hidden state forward.
        for t in range(seq_size):
            hidden_t = self.rnn_cell(x_in[t], hidden_t)
            hiddens.append(hidden_t)

        if hiddens:
            hiddens = torch.stack(hiddens)
        else:
            # torch.stack rejects an empty list; a zero-length sequence
            # yields an empty (0, batch_size, hidden_size) result instead.
            hiddens = hidden_t.new_zeros((0, batch_size, self.hidden_size))

        if self.batch_first:
            # Restore the caller's (batch, seq, hidden) layout.
            hiddens = hiddens.permute(1, 0, 2)

        return hiddens

In [3]:
# Dimensions for the demo model: input feature size and hidden state size.
input_size, hidden_size = 512, 100

# Build the model and show its module structure.
rnn = ElmanRNN(input_size=input_size, hidden_size=hidden_size)
print(rnn)

ElmanRNN(
  (rnn_cell): RNNCell(512, 100)
)
