In [None]:
import torch.nn as nn
import torch

In [None]:
class Attn(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder states.

    Each encoder state h_j is scored against the decoder state s as
    ``w3(tanh(w1(h_j) + w2(s)))``. The scores are normalised with a softmax
    over the sequence positions, and the result is the weighted sum of the
    encoder states (the context vector).
    """

    def __init__(self, dim_h, dim_s, dim_c):
        """
        dim_h: the number of features of each hidden state of the encoder
        dim_s: the number of features of each hidden state of the decoder
        dim_c: the number of features of the intermediate space in which the
               projected encoder and decoder states are combined
        """
        super().__init__()

        self.dim_h = dim_h
        self.dim_s = dim_s
        self.dim_c = dim_c

        # Projects the encoder hidden states into the combination space.
        self.w1 = nn.Linear(dim_h, dim_c, bias=False)

        # Projects the decoder hidden state into the same space. A single bias
        # is enough for the sum w1(h) + w2(s), so only this layer carries one.
        self.w2 = nn.Linear(dim_s, dim_c, bias=True)

        # Reduces each combined vector to one scalar score per encoder position.
        self.w3 = nn.Linear(dim_c, 1, bias=False)

        # Normalises the scores over the sequence axis, which is dim=-2 of the
        # (..., seq_len, 1) score tensor. With no explicit dim, PyTorch picks
        # dim=1 for 2-D input -- the singleton score axis -- which would make
        # every attention weight exactly 1.
        self.a_ij = nn.Softmax(dim=-2)

    def forward(self, hidden_encodes, hidden_decode):
        """Compute the attention context vector.

        hidden_encodes: encoder states, shape (seq_len, dim_h), or
                        (batch, seq_len, dim_h) for batched input
        hidden_decode:  decoder state, shape (1, dim_s), or (batch, 1, dim_s)
        returns:        context, shape (1, dim_h), or (batch, 1, dim_h)
        """
        # tanh makes each score a non-linear function of both inputs. Without
        # it, the decoder term adds the same constant to every position's
        # score; softmax cancels that constant, so the weights would ignore
        # the decoder state entirely.
        comb = torch.tanh(self.w1(hidden_encodes) + self.w2(hidden_decode))

        # One score per encoder position: (..., seq_len, 1).
        scores = self.w3(comb)

        # Attention weights that sum to 1 over the seq_len axis.
        weights = self.a_ij(scores)

        # (..., 1, seq_len) @ (..., seq_len, dim_h) -> (..., 1, dim_h)
        context = torch.matmul(weights.transpose(-2, -1), hidden_encodes)

        return context

In [None]:
# TESTING THE CORRECTNESS IN DIMENSION
# Seed the RNG so the random inputs and the layer initialisation are
# reproducible on every Restart & Run All.
torch.manual_seed(0)

dim_h = 30
dim_s = 20
dim_c = 40
seq_len = 15

attn = Attn(dim_h, dim_s, dim_c)

hidden_encodes = torch.rand(seq_len, dim_h)
hidden_decode = torch.rand(1, dim_s)

context = attn(hidden_encodes, hidden_decode)

# Check the shape explicitly instead of relying on reading the output:
# a single row with the encoder's feature size.
assert context.shape == (1, dim_h), context.shape

context.size()




torch.Size([1, 30])