In [22]:
import torch
import d2l
import math

class PositionWiseFFN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dense1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense2 = torch.nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        return self.dense2(self.relu(self.dense1(x)))

In [23]:
ffn = PositionWiseFFN(4,4,8)
ffn.eval()
ffn(torch.ones((2,3,4)))[0]

tensor([[ 0.0455,  0.4076,  0.5875, -0.3784, -0.0044,  0.2528,  0.4329,  0.3892],
        [ 0.0455,  0.4076,  0.5875, -0.3784, -0.0044,  0.2528,  0.4329,  0.3892],
        [ 0.0455,  0.4076,  0.5875, -0.3784, -0.0044,  0.2528,  0.4329,  0.3892]],
       grad_fn=<SelectBackward0>)

In [24]:
ln = torch.nn.LayerNorm(2)
bn = torch.nn.BatchNorm1d(2)
x = torch.tensor([[1,2],[2,3]],dtype=torch.float32)
ln(x),bn(x)



(tensor([[-1.0000,  1.0000],
         [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[-1.0000, -1.0000],
         [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>))

In [25]:
class AddNorm(torch.nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dropout = torch.nn.Dropout(dropout)
        self.ln = torch.nn.LayerNorm(normalized_shape)
    
    def forward(self,x,y):
        return self.ln(self.dropout(y)+x)

In [26]:
add_norm = AddNorm([3,4],0.5)
add_norm.eval()
add_norm(torch.ones(2,3,4),torch.ones(2,3,4))



tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], grad_fn=<NativeLayerNormBackward0>)

In [29]:
class TransformerEncoderBlock(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, fnn_hidden_size, dropout, bias = False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(hidden_size, num_heads, dropout, bias)
        self.add_norm1 = AddNorm(hidden_size, dropout)
        self.fnn = PositionWiseFFN(hidden_size, fnn_hidden_size, hidden_size)
        self.add_norm1 = AddNorm(hidden_size, dropout)

    def forward(self, x, valid_lens):
        y = self.add_norm1(x, self.attention(x,x,x, valid_lens))
        return self.add_norm2(y, self.fnn(y))



In [36]:
#import importlib
#importlib.reload(d2l)

x = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 8, 48, 0.5)
encoder_blk.eval()
encoder_blk(x, valid_lens)

TypeError: TransformerEncoderBlock.forward() takes 2 positional arguments but 3 were given

In [31]:

class TransformerEncoder(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, num_heads, fnn_hidden_size, num_blocks, dropout, bias = False) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = d2l.PositionalEncoding(hidden_size, dropout)
        self.blocks = torch.nn.Sequential()
        for i in range(num_blocks):
            self.blocks.add_module("block"+str(i), TransformerEncoderBlock(hidden_size, num_heads, fnn_hidden_size,dropout,bias))
        
    def forward(self, x, valid_lens):
        x = self.pos_encoding(self.embedding(x)*math.sqrt(self.hidden_size))
        self.attention_weights = [None]*len(self.blocks)
        for i, block in enumerate(self.blocks):
            x = block(x, valid_lens)
            self.attention_weights[i] = block.attention.attention.attention_weights
        return x

