In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Multi-Layer RNN

In [17]:
class RNNCell(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.i2h = nn.Linear(input_sz, hidden_sz)
        self.h2h = nn.Linear(hidden_sz, hidden_sz)

    def forward(self, x, h_prev):
        hx = self.i2h(x)
        hh = self.h2h(h_prev)
        h = torch.tanh(hx + hh)
        return h

In [31]:
# a multi-layer RNN
class RNN(nn.Module):
    def __init__(self, input_sz, hidden_sz, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_sz = hidden_sz
        self.rnn = nn.ModuleList([RNNCell(input_sz, hidden_sz) if i == 0 else RNNCell(hidden_sz, hidden_sz)
            for i in range(n_layers)
        ])

    def forward(self, x):
        batch_sz, seq_len, _ = x.shape
        h = [torch.zeros(batch_sz, self.hidden_sz)for _ in range(n_layers)]
        outputs = []
        for t in range(seq_len):
            input_t = x[:, t, :]
            for layer in range(self.n_layers):
                h[layer] = self.rnn[layer](input_t, h[layer])
                input_t = h[layer] # output of current layer is input for next layer
            outputs.append(h[-1].unsqueeze(1)) # store last layer output
        outputs = torch.cat(outputs, dim=1)
        return outputs, h[-1] # outputs and final hidden state
                

In [62]:
# params
input_sz = 3 # size of input features
hidden_sz = 4 # size of hidden state
n_layers = 2 # number of layers of rnn cells
batch_sz = 2
seq_len = 5

# batch of input data (batch_size = 2, seq_length = 5, input_size = 3)
x = torch.randn(batch_sz, seq_len, input_sz)

# forward pass through RNN
rnn_model = RNN(input_sz, hidden_sz, n_layers)
output, h_final = rnn_model(x)
print(f"Output shape: {output.shape}")  # (2, 5, 4)
print(f"Final hidden state: {h_final}")

Output shape: torch.Size([2, 5, 4])
Final hidden state: tensor([[-0.6623,  0.7879,  0.6665,  0.4409],
        [-0.0420,  0.8941,  0.6779,  0.4650]], grad_fn=<TanhBackward0>)


# GRU 

In [57]:
class GRUCell(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.Wr = nn.Linear(input_sz, hidden_sz, bias=True)
        self.Ur = nn.Linear(hidden_sz, hidden_sz, bias=False)
        self.Wz = nn.Linear(input_sz, hidden_sz, bias=True)
        self.Uz = nn.Linear(hidden_sz, hidden_sz, bias=False)
        self.Wh = nn.Linear(input_sz, hidden_sz, bias=True)
        self.Uh = nn.Linear(hidden_sz, hidden_sz, bias=False)

    def forward(self, x, h_prev):
        # compute reset gate
        r_x = self.Wr(x)
        r_h = self.Ur(h_prev)
        r = torch.sigmoid(r_x + r_h)

        # compute hidden gate
        z_x = self.Wz(x)
        z_h = self.Uz(h_prev)
        z = torch.sigmoid(z_x + z_h)

        # compute candidate hidden state
        h_x = self.Wh(x)
        h_h = self.Uh(h_prev)
        # with the candidate hidden (h_tilde) when rest gate is close to 0, previous hidden state is ignored
        # and cadidate_hidden is reset with h_x (or current input)
        h_tilde = torch.tanh(h_x + (r * h_h)) # different from paper's implementation but how pytorch does it

        # compute final hidden state
        # below update gate determines how much of candiate_hidden vs previous hidden state to return as new hidden state
        h = (z * h_prev) + ((1 - z) * h_tilde)
        return h

In [99]:
class GRU(nn.Module):
    def __init__(self, input_sz, hidden_sz, n_layers):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.gru = nn.ModuleList([GRUCell(input_sz, hidden_sz) if i == 0 else GRUCell(hidden_sz, hidden_sz)
            for i in range(self.n_layers)
        ])

    def forward(self, x):
        batch_sz, seq_len, _ = x.shape
        h = [torch.zeros(batch_sz, self.hidden_sz) for _ in range(self.n_layers)]
        # Process the input sequence step-by-step
        outputs = []
        for t in range(seq_length):
            input_t = x[:, t, :]  # Extract input at time t

            # Pass through each layer's GRU cell
            for layer in range(self.n_layers):
                h[layer] = self.gru[layer](input_t, h[layer])
                input_t = h[layer]  # Output of this layer becomes input for the next

            outputs.append(h[-1].unsqueeze(1))  # Store output for this step

        # Concatenate all outputs along the sequence dimension
        outputs = torch.cat(outputs, dim=1)
        return outputs


In [100]:
# input params
input_sz = 3 # size of features in input
hidden_sz = 5 # size of hidden states
seq_len = 4 # size of the sequence
n_layers = 2 # number of gru layers
batch_sz = 2 # number of sequences in a batch

# create input of (batch_sz, seq_len, input_sz)
x = torch.randn(batch_sz, seq_len, input_sz)

gru_model = GRU(input_sz, hidden_sz, n_layers)

# one forward pass
output = gru_model(x)


print(f"Output shape: {output.shape}")  #  Should print: (2, 4, 5)
print(f"Final hidden state: {h_final}")

Output shape: torch.Size([2, 4, 5])
Final hidden state: tensor([[-0.6623,  0.7879,  0.6665,  0.4409],
        [-0.0420,  0.8941,  0.6779,  0.4650]], grad_fn=<TanhBackward0>)


# Enocder-Decoder

In [103]:
class GRUEncoder(nn.Module):
    def __init__(self, input_sz, hidden_sz, n_layers):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.gru = nn.ModuleList([GRUCell(input_sz, hidden_sz) if i == 0 else GRUCell(hidden_sz, hidden_sz)
            for i in range(self.n_layers)
        ])
        self.V = nn.Linear(self.hidden_sz, self.hidden_sz)

    def forward(self, x):
        batch_sz, seq_len, _ = x.shape
        h = [torch.zeros(batch_sz, self.hidden_sz) for _ in range(self.n_layers)]
        # Process the input sequence step-by-step
        outputs = []
        for t in range(seq_length):
            input_t = x[:, t, :]  # Extract input at time t

            # Pass through each layer's GRU cell
            for layer in range(self.n_layers):
                h[layer] = self.gru[layer](input_t, h[layer])
                input_t = h[layer]  # Output of this layer becomes input for the next

            outputs.append(h[-1].unsqueeze(1))  # Store output for this step

        Vh = self.V(h[-1].unsqueeze(1))
        # c is the representation of the source phrase. It is computed with the hidden state at the N step
        # i.e., at the end of the source phrase
        c = torch.tanh(Vh)
        
        return c, h[-1]


In [104]:
class GRUDencoder(nn.Module):
    def __init__(self, input_sz, hidden_sz, n_layers):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.gru = nn.ModuleList([GRUCell(input_sz, hidden_sz) if i == 0 else GRUCell(hidden_sz, hidden_sz)
            for i in range(self.n_layers)
        ])
        self.V = nn.Linear(self.hidden_sz, self.hidden_sz)

    def forward(self, y, h):
        batch_sz, seq_len, _ = y.shape
        Vc = self.V(h)
        h_prime = torch.tanh(Vc)
        # Process the input sequence step-by-step
        all_logits = [] # prob of generating the next word
        for t in range(seq_length):
            input_t = y[:, t, :]  # Extract input at time t

            # Pass through each layer's GRU cell
            for layer in range(self.n_layers):
                h_prime[layer] = self.gru[layer](input_t, h_prime[layer])
                input_t = h_prime[layer]  # Output of this layer becomes input for the next
                raw_logits = h_prime[-1].unsqueeze(1)
                raw_logits = torch.relu(raw_logits) # use this forn ow instead of maxout

            all_logits.append(row_logits)  # Store output for this step
        logits = torch.vstack(all_logits, dim=1)
        probabilities = torch.softmax(logits)
        return probabilities