In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DeepRnnHighway(nn.Module):
    """Stack of recurrent layers whose state update is blended by two sigmoid gates.

    At every time step the input is projected to ``h_dim`` by ``in2h`` and then
    passed upward through ``n_layers`` layers. With ``h_i`` the state of layer
    ``i`` and ``below`` the output of the layer underneath (the projected input
    for layer 0), each layer computes::

        a   = sigmoid(h_i W_i below^T)      # one scalar per sample
        g   = sigmoid(h_i * below)          # element-wise
        h_i = a * (1 - g) * h_i + (1 - a) * g * below

    Args:
        in_dim: Size of the last dimension of the input.
        h_dim: Size of every layer's hidden state.
        n_layers: Number of stacked layers.
    """

    def __init__(self, in_dim, h_dim, n_layers) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.h_dim = h_dim
        self.n_layers = n_layers

        self.in2h = nn.Linear(in_dim, h_dim)
        # A ParameterList (rather than a plain Python list) registers the
        # matrices with the module, so they show up in ``parameters()`` and
        # ``state_dict()`` and follow ``.to(device)``.
        self.W = nn.ParameterList(
            [nn.Parameter(torch.rand(h_dim, h_dim)) for _ in range(n_layers)]
        )

    def init(self) -> None:
        """Re-initialise every per-layer matrix in ``W`` with Xavier-uniform values."""
        for w in self.W:
            nn.init.xavier_uniform_(w)

    def forward(
        self,
        x: torch.Tensor,
        h: torch.Tensor = None,
        batch_first: bool = True,
        verbose: bool = False,
    ) -> torch.Tensor:
        """Run the stack over a whole sequence and return the final layer states.

        Args:
            x: Input of shape ``(batch, seq_len, in_dim)`` when ``batch_first``
                is true, otherwise ``(seq_len, batch, in_dim)``.
            h: Optional initial state of shape ``(n_layers, batch, h_dim)``.
                When omitted, a uniform-random state is drawn on ``x``'s device.
                The tensor passed in is never modified.
            batch_first: Whether the first axis of ``x`` is the batch axis.
            verbose: When true, print a separator per time step and the gate
                ``a`` of every layer (debugging aid, off by default).

        Returns:
            A new tensor of shape ``(n_layers, batch, h_dim)`` holding each
            layer's state after the last time step.
        """
        if batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if h is None:
            h = torch.rand(
                self.n_layers, batch_size, self.h_dim,
                device=x.device, dtype=x.dtype,
            )

        # Keep the per-layer states in a list and rebuild the stacked tensor at
        # the end: writing into ``h[i]`` in place would mutate the caller's
        # tensor and overwrite values autograd still needs for backward.
        states = list(h.unbind(0))

        for t in range(seq_len):
            if verbose:
                print(80*'=')
            below = self.in2h(x[t])
            for i, weight in enumerate(self.W):
                state = states[i]
                # Per-sample bilinear form h_i W below^T with shape (batch, 1).
                # A plain ``state @ weight @ below.T`` yields (batch, batch),
                # which mixes samples and only broadcasts when batch == 1.
                a = torch.sigmoid(((state @ weight) * below).sum(dim=-1, keepdim=True))
                g = torch.sigmoid(state * below)
                state = a * (1 - g) * state + (1 - a) * g * below
                states[i] = state
                below = state  # this layer's new state feeds the layer above
                if verbose:
                    print(f'L-{i}: {a.detach()}')

        # With zero layers there is nothing to stack; hand back the empty state.
        return torch.stack(states) if states else h

In [None]:
# Demo configuration: input feature size, hidden size and number of stacked layers.
IN_DIM, H_DIM, N_LAYERS = 128, 256, 3
BATCH_SIZE, SEQ_LEN = 1, 5

# Seed the global RNG so the random weights, input and default initial state
# are the same after every Restart & Run All.
torch.manual_seed(0)

model = DeepRnnHighway(IN_DIM, H_DIM, N_LAYERS)
model.init()  # replace the uniform [0, 1) recurrent matrices with Xavier-uniform values

h = None  # no initial hidden state: the model draws one on its first call
x = torch.rand(BATCH_SIZE, SEQ_LEN, IN_DIM)  # batch-first, matching forward's default batch_first=True

In [None]:
# Disabled alternative: draw a fresh random input before running the model.
# x = torch.rand(1, 5, 128)
# NOTE(review): `h` is rebound to the model output, so re-running this cell
# feeds the previous hidden state back in instead of starting from `None`.
# Re-run the setup cell above (which sets `h = None`) to reset — confirm this
# carry-over between runs is intended.
h = model(x, h)
print(h.shape)