In [2]:
import math
import torch
import d2l

class PositionWiseFFN(torch.nn.Module):
    """Two-layer feed-forward network: ``dense2(relu(dense1(x)))``.

    Args:
        input_size: number of input features expected by the first linear layer.
        hidden_size: number of features between the two linear layers.
        output_size: number of features produced by the second linear layer.
        **kwargs: passed through unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_size, hidden_size, output_size, **kwargs) -> None:
        super().__init__(**kwargs)
        # Sub-modules are created in the order dense1 -> relu -> dense2 so that
        # parameter registration (and seeded initialisation) is unaffected.
        self.dense1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Apply linear -> ReLU -> linear to ``x`` and return the result."""
        hidden = self.dense1(x)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [23]:
# Sanity check: feed an all-ones (2, 3, 4) tensor through a 4 -> 4 -> 8 network
# and display the result for the first batch element only.
ffn = PositionWiseFFN(input_size=4, hidden_size=4, output_size=8)
ffn.eval()
ffn(torch.ones((2, 3, 4)))[0]

tensor([[ 0.3742, -0.0382,  0.5203,  0.3202, -0.1715,  0.6132,  0.1372, -0.8487],
        [ 0.3742, -0.0382,  0.5203,  0.3202, -0.1715,  0.6132,  0.1372, -0.8487],
        [ 0.3742, -0.0382,  0.5203,  0.3202, -0.1715,  0.6132,  0.1372, -0.8487]],
       grad_fn=<SelectBackward0>)

In [3]:
# Compare the two normalisations on the same 2x2 input. As the displayed
# output shows, LayerNorm standardises each row, while BatchNorm1d (in its
# default training mode) standardises each column.
ln = torch.nn.LayerNorm(normalized_shape=2)
bn = torch.nn.BatchNorm1d(num_features=2)
x = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
(ln(x), bn(x))



(tensor([[-1.0000,  1.0000],
         [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[-1.0000, -1.0000],
         [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>))

In [3]:
class AddNorm(torch.nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(dropout(y) + x)``.

    Args:
        normalized_shape: shape handed to ``torch.nn.LayerNorm``.
        dropout: drop probability handed to ``torch.nn.Dropout``.
        **kwargs: passed through unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, normalized_shape, dropout, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.ln = torch.nn.LayerNorm(normalized_shape=normalized_shape)

    def forward(self, x, y):
        """Return ``ln(dropout(y) + x)``.

        ``x`` is the value added back unchanged; dropout is applied to ``y`` only.
        """
        regularised = self.dropout(y)
        return self.ln(regularised + x)

In [5]:
# In eval mode dropout is a no-op, so the input to LayerNorm is a constant
# tensor (1 + 1 everywhere) and the normalised result is all zeros.
add_norm = AddNorm(normalized_shape=[3, 4], dropout=0.5)
add_norm.eval()
add_norm(torch.ones(2, 3, 4), torch.ones(2, 3, 4))



tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], grad_fn=<NativeLayerNormBackward0>)

In [9]:
class TransformerEncoderBlock(torch.nn.Module):
    """One encoder layer: self-attention and a feed-forward network, each
    wrapped in an ``AddNorm`` residual step.

    Args:
        hidden_size: feature size used for attention, both ``AddNorm`` layers
            and the input/output of the feed-forward network.
        num_heads: forwarded to ``d2l.MultiHeadAttention``.
        fnn_hidden_size: inner width of the feed-forward network.
        dropout: forwarded to the attention module and both ``AddNorm`` layers.
        bias: forwarded to ``d2l.MultiHeadAttention`` (default ``False``).
    """

    def __init__(self, hidden_size, num_heads, fnn_hidden_size, dropout, bias = False):
        super().__init__()
        # d2l.MultiHeadAttention is defined outside this notebook; its arguments
        # are passed positionally exactly as before.
        self.attention = d2l.MultiHeadAttention(hidden_size, num_heads, dropout, bias)
        self.add_norm1 = AddNorm(normalized_shape=hidden_size, dropout=dropout)
        self.fnn = PositionWiseFFN(
            input_size=hidden_size,
            hidden_size=fnn_hidden_size,
            output_size=hidden_size,
        )
        self.add_norm2 = AddNorm(normalized_shape=hidden_size, dropout=dropout)

    def forward(self, x, valid_lens):
        """Run self-attention then the feed-forward network, with a residual
        ``AddNorm`` after each; ``valid_lens`` is handed to the attention module.
        """
        attended = self.attention(x, x, x, valid_lens)
        y = self.add_norm1(x, attended)
        transformed = self.fnn(y)
        return self.add_norm2(y, transformed)

In [None]:
# Smoke test for a single encoder block: batch of 2, 100 positions, 24 features,
# with per-sequence valid lengths 3 and 2. eval() switches dropout off.
x = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(
    hidden_size=24, num_heads=8, fnn_hidden_size=48, dropout=0.5
)
encoder_blk.eval()
encoder_blk(x, valid_lens)

In [13]:

class TransformerEncoder(torch.nn.Module):
    """Embedding, positional encoding and a stack of ``TransformerEncoderBlock``.

    Args:
        vocab_size: number of rows in the embedding table.
        hidden_size: embedding width and feature size of every block.
        num_heads: forwarded to each block.
        fnn_hidden_size: forwarded to each block.
        num_blocks: how many encoder blocks to stack.
        dropout: forwarded to the positional encoding and to each block.
        bias: forwarded to each block (default ``False``).

    After ``forward`` has run, ``attention_weights`` holds one entry per block.
    """

    def __init__(self, vocab_size, hidden_size, num_heads, fnn_hidden_size, num_blocks, dropout, bias = False) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = d2l.PositionalEncoding(hidden_size, dropout)
        self.blocks = torch.nn.Sequential()
        for idx in range(num_blocks):
            encoder_block = TransformerEncoderBlock(
                hidden_size, num_heads, fnn_hidden_size, dropout, bias
            )
            # Blocks are registered as "block0", "block1", ...
            self.blocks.add_module(f"block{idx}", encoder_block)

    def forward(self, x, valid_lens=None):
        """Encode ``x`` and record each block's attention weights.

        The embedding is multiplied by ``sqrt(hidden_size)`` before the
        positional encoding is applied. ``valid_lens`` is passed to every block.
        """
        scale = math.sqrt(self.hidden_size)
        x = self.pos_encoding(self.embedding(x) * scale)
        # One slot per block, filled in as the blocks are executed.
        self.attention_weights = [None for _ in self.blocks]
        for idx, block in enumerate(self.blocks):
            x = block(x, valid_lens)
            # Read from the attention object nested inside d2l.MultiHeadAttention.
            inner_attention = block.attention.attention
            self.attention_weights[idx] = inner_attention.attention_weights
        return x


In [15]:
# Shape check for a 2-block encoder on a (2, 100) batch of token ids.
# The model is left in training mode here; only the output shape is inspected.
encoder = TransformerEncoder(
    vocab_size=200,
    hidden_size=24,
    num_heads=8,
    fnn_hidden_size=48,
    num_blocks=2,
    dropout=0.5,
)
encoder(torch.ones((2, 100), dtype=torch.long)).shape

torch.Size([2, 100, 24])

In [4]:
class TransformerDecoderBlock(torch.nn.Module):
    """The ``i``-th decoder layer: self-attention, encoder-decoder attention
    and a feed-forward network, each followed by an ``AddNorm`` residual step.

    Args:
        hidden_size: feature size used by both attention modules, the three
            ``AddNorm`` layers and the input/output of the feed-forward network.
        num_heads: forwarded to both ``d2l.MultiHeadAttention`` modules.
        fnn_hidden_size: inner width of the feed-forward network.
        dropout: forwarded to the attention modules and ``AddNorm`` layers.
        i: index of this block; selects this block's slot in ``state[2]``.
    """

    def __init__(self, hidden_size, num_heads, fnn_hidden_size, dropout, i) -> None:
        super().__init__()
        self.i=i
        self.attention1 = d2l.MultiHeadAttention(hidden_size,num_heads,dropout)
        self.add_norm1 = AddNorm(normalized_shape=hidden_size, dropout=dropout)
        self.attention2 = d2l.MultiHeadAttention(hidden_size,num_heads,dropout)
        self.add_norm2 = AddNorm(normalized_shape=hidden_size, dropout=dropout)
        self.fnn = PositionWiseFFN(
            input_size=hidden_size,
            hidden_size=fnn_hidden_size,
            output_size=hidden_size,
        )
        self.add_norm3 = AddNorm(normalized_shape=hidden_size, dropout=dropout)

    def forward(self, x, state):
        """Run the block on ``x`` and return ``(output, state)``.

        ``state`` is ``[enc_outputs, enc_valid_lens, per_block_cache]``. The
        slot ``state[2][self.i]`` is updated in place: it becomes ``x`` when it
        was ``None``, otherwise the previous content with ``x`` appended along
        dimension 1.
        """
        enc_outputs, enc_valid_lens, cache = state[0], state[1], state[2]

        # Keys/values for self-attention: everything cached so far plus x.
        cached = cache[self.i]
        key_values = x if cached is None else torch.cat((cached, x), dim=1)
        cache[self.i] = key_values

        # Training: row b of dec_valid_lens is [1, 2, ..., num_steps].
        # Presumably this limits position t to the first t keys; that depends
        # on how d2l.MultiHeadAttention interprets valid_lens (confirm there).
        # Outside training no lengths are passed.
        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = x.shape
            steps = torch.arange(1, num_steps + 1, device=x.device)
            dec_valid_lens = steps.repeat(batch_size, 1)

        self_attended = self.attention1(x, key_values, key_values, dec_valid_lens)
        y = self.add_norm1(x, self_attended)
        cross_attended = self.attention2(y, enc_outputs, enc_outputs, enc_valid_lens)
        z = self.add_norm2(y, cross_attended)
        output = self.add_norm3(z, self.fnn(z))
        return output, state
    

In [5]:
class TransformerDecoder(torch.nn.Module):
    """Embedding, positional encoding, a stack of ``TransformerDecoderBlock``
    and a final linear projection to ``vocab_size``.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output projection.
        hidden_size: embedding width and feature size of every block.
        num_heads: forwarded to each block.
        fnn_hidden_size: forwarded to each block.
        num_blocks: how many decoder blocks to stack.
        dropout: forwarded to the positional encoding and to each block.

    After ``forward`` has run, ``attention_weights[0][i]`` and
    ``attention_weights[1][i]`` hold the weights read from block ``i``'s
    ``attention1`` and ``attention2`` respectively.
    """

    def __init__(self, vocab_size, hidden_size, num_heads, fnn_hidden_size, num_blocks, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = d2l.PositionalEncoding(hidden_size, dropout)
        self.blocks = torch.nn.Sequential()
        for idx in range(num_blocks):
            decoder_block = TransformerDecoderBlock(
                hidden_size, num_heads, fnn_hidden_size, dropout, idx
            )
            # Blocks are registered as "block0", "block1", ...
            self.blocks.add_module(f"block{idx}", decoder_block)
        self.dense = torch.nn.Linear(hidden_size, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Build the initial decoder state.

        Returns ``[enc_outputs, enc_valid_lens, cache]`` where ``cache`` is a
        list of ``num_blocks`` ``None`` entries, one slot per block.
        """
        empty_cache = [None] * self.num_blocks
        return [enc_outputs, enc_valid_lens, empty_cache]

    def forward(self, x, state):
        """Decode ``x`` and return ``(dense(output), state)``.

        The embedding is multiplied by ``sqrt(hidden_size)`` before the
        positional encoding is applied; each block receives and returns
        ``state``.
        """
        scale = math.sqrt(self.hidden_size)
        x = self.pos_encoding(self.embedding(x) * scale)
        # Two independent rows: [0] for attention1, [1] for attention2.
        first_attention = [None] * self.num_blocks
        second_attention = [None] * self.num_blocks
        self.attention_weights = [first_attention, second_attention]
        for idx, block in enumerate(self.blocks):
            x, state = block(x, state)
            first_attention[idx] = block.attention1.attention.attention_weights
            second_attention[idx] = block.attention2.attention.attention_weights
        return self.dense(x), state

        