In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class MyGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        #Reset gate
        self.Wr = nn.Linear(input_size, hidden_size)
        self.Ur = nn.Linear(hidden_size, hidden_size, bias=False)

        # Update gate
        self.Wz = nn.Linear(input_size, hidden_size)
        self.Uz = nn.Linear(hidden_size, hidden_size, bias=False)

        # Candidate
        self.Wh = nn.Linear(input_size, hidden_size)
        self.Uh = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x_t, h_prev):
        # Reset and update gates
        r = torch.sigmoid(self.Wr(x_t) + self.Ur(h_prev))
        z = torch.sigmoid(self.Ur(x_t) + self.Ur(h_prev))

        # Candidate hidden state
        h_tilde = torch.tanh(self.Wh(x_t) + self.Uh(r * h_prev))

        # Final hidden state
        h_t = ((1-z)*h_prev)+(z*h_tilde)
        return h_t


In [None]:
class MyGRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = MyGRUCell(input_size, hidden_size)

    def forward(self, x):
        """
        x: (batch, seq_len, input_size)
        """
        

Create a toy dataset

In [4]:
def generate_data(N=200, seq_len=10):
    X, Y = [], []
    for _ in range(N):
        seq = torch.sin(torch.linspace(0, 3.14, seq_len+1)) + 0.1*torch.randn(seq_len+1)
        X.append(seq[:-1].unsqueeze(-1))
        Y.append(seq[1:].unsqueeze(-1))
        return torch.stack(X), torch.stack(Y)
X, Y = generate_data()
print("X: ", X)
print("Y: ", Y)


X:  tensor([[[-0.0403],
         [ 0.2180],
         [ 0.5806],
         [ 0.7972],
         [ 0.9806],
         [ 0.8348],
         [ 0.8140],
         [ 0.7307],
         [ 0.4437],
         [ 0.4677]]])
Y:  tensor([[[ 0.2180],
         [ 0.5806],
         [ 0.7972],
         [ 0.9806],
         [ 0.8348],
         [ 0.8140],
         [ 0.7307],
         [ 0.4437],
         [ 0.4677],
         [-0.1456]]])
