<a href="https://colab.research.google.com/github/simply-pouria/The-LMs-Book/blob/main/TheLMBook_Chapter3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing an RNN

### Elman RNN Unit

In [4]:
import torch
import torch.nn as nn
class ElmanRNNUnit(nn.Module):
    """Single Elman recurrent cell.

    Computes ``tanh(x @ Wh + h @ Uh + b)`` where ``Wh`` and ``Uh`` are
    square ``emb_dim x emb_dim`` matrices and ``b`` has length ``emb_dim``.

    Args:
        emb_dim: Size of both the input and the hidden-state vectors.
    """

    def __init__(self, emb_dim):
        super().__init__()
        square = (emb_dim, emb_dim)
        # Creation order (Uh, then Wh, then b) fixes both the random draws
        # under a given seed and the order of the registered parameters.
        self.Uh = nn.Parameter(torch.randn(*square))  # hidden-to-hidden
        self.Wh = nn.Parameter(torch.randn(*square))  # input-to-hidden
        self.b = nn.Parameter(torch.zeros(emb_dim))   # bias, starts at zero

    def forward(self, x, h):
        """Return the new hidden state for input ``x`` and previous state ``h``."""
        input_term = x @ self.Wh
        recurrent_term = h @ self.Uh
        return torch.tanh(input_term + recurrent_term + self.b)

### Implementing the Elman RNN itself

In [3]:
class ElmanRNN(nn.Module):
    """Multi-layer Elman RNN built from stacked ``ElmanRNNUnit`` cells.

    At every time step the input is passed through the layers in order;
    each layer's new hidden state becomes the input of the next layer, and
    the top layer's state is collected as that step's output.

    Args:
        emb_dim: Size handed to every ``ElmanRNNUnit``.
        num_layers: Number of stacked recurrent units.
    """

    def __init__(self, emb_dim, num_layers):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.rnn_units = nn.ModuleList(
            [ElmanRNNUnit(emb_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Run the stacked RNN over a batch of sequences.

        Args:
            x: Tensor whose shape unpacks as ``(batch_size, seq_len, emb_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, emb_dim)`` holding the
            top layer's hidden state at every time step. An input with
            ``seq_len == 0`` yields an empty ``(batch_size, 0, emb_dim)``
            tensor.
        """
        batch_size, seq_len, emb_dim = x.shape

        # torch.stack rejects an empty list, so handle empty sequences here.
        if seq_len == 0:
            return x.new_zeros(batch_size, 0, emb_dim)

        # new_zeros inherits both device AND dtype from x, so the initial
        # states stay compatible with models cast to float64/float16/etc.
        h_prev = [
            x.new_zeros(batch_size, emb_dim) for _ in range(self.num_layers)
        ]

        outputs = []
        for t in range(seq_len):
            input_t = x[:, t]
            for layer_idx, rnn_unit in enumerate(self.rnn_units):
                h_new = rnn_unit(input_t, h_prev[layer_idx])
                h_prev[layer_idx] = h_new  # carry state to the next time step
                input_t = h_new            # feed the next layer
            outputs.append(input_t)        # top-layer output for step t
        return torch.stack(outputs, dim=1)

### RNN as a Language Model

In [None]:
class RecurrentLanguageModel(nn.Module):
    """Language model: token embedding -> Elman RNN -> linear vocabulary head.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        emb_dim: Embedding width, also used as the RNN size.
        num_layers: Number of stacked layers in the ``ElmanRNN``.
        pad_idx: Forwarded to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, emb_dim, num_layers, pad_idx):
        super().__init__()
        # Submodules are created in the order embedding -> rnn -> fc.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = ElmanRNN(emb_dim, num_layers)
        self.fc = nn.Linear(emb_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to per-position vocabulary logits."""
        hidden_states = self.rnn(self.embedding(x))
        return self.fc(hidden_states)