In [4]:
import torch
import torch.nn as nn

In [200]:
class custom_GRU_cell(torch.nn.Module):
    """Single-step gated recurrent unit (GRU) cell built from ``nn.Linear`` layers.

    For an input ``x_t`` and a previous hidden state ``h_prev`` the cell computes::

        z     = sigmoid(W_z x_t + U_z h_prev)         # update gate
        r     = sigmoid(W_r x_t + U_r h_prev)         # reset gate
        h_hat = tanh(W_h x_t + r * (U_h h_prev))      # candidate state
        h     = z * h_prev + (1 - z) * h_hat          # new hidden state

    Every ``W_*`` / ``U_*`` is a separate ``nn.Linear`` and therefore carries
    its own bias term.

    Args:
        input_dim: Size of the last dimension of ``x_t``.
        hidden_dim: Size of the last dimension of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # update gate
        self.linear_w_z = nn.Linear(self.input_dim, self.hidden_dim)
        self.linear_u_z = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.activation_z = nn.Sigmoid()

        # reset gate
        self.linear_w_r = nn.Linear(self.input_dim, self.hidden_dim)
        self.linear_u_r = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.activation_r = nn.Sigmoid()

        # candidate hidden state
        self.linear_w_h = nn.Linear(self.input_dim, self.hidden_dim)
        self.linear_u_h = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.activation_h = nn.Tanh()

    def forward(self, x_t, h_prev):
        """Advance the recurrence by one step.

        Args:
            x_t: Tensor whose last dimension is ``input_dim``.
            h_prev: Tensor whose last dimension is ``hidden_dim``; its leading
                dimensions must be broadcast-compatible with those of ``x_t``.

        Returns:
            The new hidden state, with last dimension ``hidden_dim``.
        """
        update_gate = self.activation_z(self.linear_w_z(x_t) + self.linear_u_z(h_prev))
        reset_gate = self.activation_r(self.linear_w_r(x_t) + self.linear_u_r(h_prev))

        # The reset gate scales the *projected* previous state (U_h h_prev),
        # not h_prev itself.
        candidate = self.activation_h(
            self.linear_w_h(x_t) + reset_gate * self.linear_u_h(h_prev)
        )

        # Convex combination of the old state and the candidate. The scalar 1
        # broadcasts on the gate's own device/dtype, so no explicit ones tensor
        # or device lookup is needed.
        return update_gate * h_prev + (1 - update_gate) * candidate
    
    
class custom_GRU(torch.nn.Module):
    """Runs a ``custom_GRU_cell`` over the time axis of a batch of sequences.

    Args:
        input_dim: Size of the last dimension of the input sequence.
        hidden_dim: Size of the hidden state produced at every step.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.cell = custom_GRU_cell(input_dim, hidden_dim)

    def forward(self, inputs):
        """Unroll the cell over dimension 1 of ``inputs``.

        Args:
            inputs: Tensor of shape ``(batch, seq_len, input_dim)``.

        Returns:
            A tuple ``(outputs, h_n)`` where

            * ``outputs`` has shape ``(batch, seq_len, hidden_dim)`` and holds
              the hidden state after every step, detached from the autograd
              graph and copied to the CPU;
            * ``h_n`` has shape ``(batch, hidden_dim)`` and is the final hidden
              state, still attached to the graph and on the device of
              ``inputs``.

            For ``seq_len == 0`` the result is an empty ``outputs`` tensor and
            an all-zero ``h_n``.
        """
        batch_size, seq_len = inputs.shape[0], inputs.shape[1]

        # Build the initial state from `inputs` so it shares its device and
        # dtype. A plain torch.zeros(...) would always be a default-dtype CPU
        # tensor and break for CUDA inputs or a .double()/.half() model.
        hidden = inputs.new_zeros(batch_size, self.hidden_dim)

        outputs = []
        for t in range(seq_len):
            hidden = self.cell(inputs[:, t], hidden)
            # NOTE(review): per-step outputs are detached and moved to the CPU
            # (kept for backward compatibility), so gradients reach the
            # parameters only through the returned final hidden state.
            outputs.append(hidden.detach().cpu())

        if outputs:
            outputs = torch.stack(outputs, dim=1)
        else:
            # torch.stack rejects an empty list; return a well-shaped empty tensor.
            outputs = hidden.new_zeros(batch_size, 0, self.hidden_dim).cpu()

        return outputs, hidden

In [195]:
# Random demo batch; custom_GRU.forward reads dim 0 as the batch and dim 1 as the time steps.
x = torch.rand((3, 5, 25))

In [196]:
# Build the GRU: 25 input features per step, hidden state of size 8.
model = custom_GRU(input_dim=25, hidden_dim=8)