In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as function

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, emb_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.value = nn.Linear(emb_dim, emb_dim)
        self.emb_dim = emb_dim
        
    def forward(self, x):
        q = self.query(x)
        k = self.query(x)
        v = self.query(x)
        attention_weights = torch.matmul(q, k.T) / np.sqrt(self.emb_dim)
        attention_weights = function.softmax(attention_weights, dim=-1)
        attention_values = torch.matmul(attention_weights, v)
        return attention_values

In [3]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        
        self.query = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.value = nn.Linear(emb_dim, emb_dim)
        self.fc = nn.Linear(emb_dim, emb_dim)
        self.norm = nn.LayerNorm(emb_dim, emb_dim)
        
    def forward(self, query, key, value, mask=None):
        seq_len, emb_size = query.size()
        q = self.query(query).view(seq_len, 2, 5).transpose(0,1)
        k = self.key(key).view(seq_len, 2, 5).transpose(0,1)
        v = self.value(value).view(seq_len, 2, 5).transpose(0,1)
        attention_weights = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if mask is not None:
            attention_weights = attention_weights.masked_fill(mask, -1e9)
        attention_weights = function.softmax(attention_weights, dim=-1)
        attention_values = torch.matmul(attention_weights, v)
        attention_values = attention_values.transpose(0, 1).contiguous().view(seq_len, emb_size)
        out = self.fc(attention_values) + value
        out = self.norm(out)
        return out

In [4]:
class FFNLayer(nn.Module):
    def __init__(self, emb_dim, ffn_dim):
        super(FFNLayer, self).__init__()
        self.linear1 = nn.Linear(emb_dim, ffn_dim)
        self.linear2 = nn.Linear(ffn_dim, emb_dim)
        
    def forward(self, x):
        x = self.linear1(x)
        x = function.relu(x)
        x = self.linear2(x)
        return x

In [5]:
class Encoder(nn.Module):
    def __init__(self, emb_dim, num_heads, ffn_dim, num_blocks):
        super(Encoder,self).__init__()
        self.num_blocks = num_blocks
        self.attention = attention = MultiHeadSelfAttention(emb_dim, num_heads)
        self.ffn = FFNLayer(emb_dim, ffn_dim)
        
    def forward(self, x, mask=None):
        # pos_emb = get_pos_emb(x)
        # x = x + pos_emb
        for _ in range(self.num_blocks):
            x = self.attention(x, x, x, None)
            x = self.ffn(x)
        return x

In [6]:
class Decoder(nn.Module):
    def __init__(self, emb_dim, num_heads, ffn_dim, num_blocks):
        super(Decoder,self).__init__()
        self.num_blocks = num_blocks
        self.attention1 = MultiHeadSelfAttention(emb_dim, num_heads)
        self.attention2 = MultiHeadSelfAttention(emb_dim, num_heads)
        self.ffn = FFNLayer(emb_dim, ffn_dim)
        
    def forward(self, x, enc_out, mask=None):
        # pos_emb = get_pos_emb(x)
        # x = x + pos_emb
        for _ in range(self.num_blocks):
            x = self.attention1(x, x, x)
            x = self.attention2(enc_out, enc_out, x)
            x = self.ffn(x)
        return x

In [7]:
embedding = torch.randint(1, 10, size=(4, 10)).float()
embedding

tensor([[3., 6., 6., 6., 3., 2., 7., 4., 8., 8.],
        [7., 5., 2., 3., 1., 8., 6., 9., 5., 2.],
        [4., 1., 4., 7., 5., 7., 7., 7., 3., 3.],
        [9., 5., 1., 2., 5., 4., 5., 7., 5., 7.]])

In [8]:
encoder = Encoder(10, 2, 5, 2)

In [9]:
out = encoder(embedding, None)
out

tensor([[ 0.2201, -0.3998, -0.1858, -0.2635,  0.1850,  0.1228, -0.3708,  0.0804,
         -0.1740,  0.2693],
        [ 0.2205, -0.3996, -0.1857, -0.2649,  0.1835,  0.1241, -0.3695,  0.0801,
         -0.1737,  0.2690],
        [ 0.2241, -0.3981, -0.1851, -0.2759,  0.1717,  0.1345, -0.3594,  0.0777,
         -0.1715,  0.2671],
        [ 0.2198, -0.4043, -0.1846, -0.2533,  0.1984,  0.1122, -0.3773,  0.0848,
         -0.1770,  0.2694]], grad_fn=<AddmmBackward0>)

In [10]:
decoder = Decoder(10, 2, 5, 3)

In [12]:
out = decoder(embedding, out)
out

tensor([[ 0.0602, -0.1558,  0.0338, -0.0008,  0.2352,  0.4724,  0.3739, -0.1169,
          0.1715,  0.0426],
        [ 0.0600, -0.1558,  0.0335, -0.0008,  0.2354,  0.4723,  0.3737, -0.1169,
          0.1714,  0.0428],
        [ 0.0600, -0.1558,  0.0336, -0.0008,  0.2353,  0.4723,  0.3738, -0.1170,
          0.1715,  0.0427],
        [ 0.0601, -0.1558,  0.0337, -0.0008,  0.2353,  0.4723,  0.3738, -0.1169,
          0.1715,  0.0427]], grad_fn=<AddmmBackward0>)