In [13]:
import numpy as np
import torch
from torch import nn

In [14]:
# note, all dimensions actually include batch dimension
# for each tensor, the dimensions look more like (sequence_batch, d_model, etc...)
# as we do a batch for each sequence
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected with their own linear
    layer, split into ``n_heads`` heads of size ``d_k = d_model // n_heads``,
    attended per head, concatenated again and passed through a final linear
    projection.

    Args:
        d_model: embedding vector size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model  # embedding vector size
        self.n_heads = n_heads  # number of heads
        assert d_model % n_heads == 0, "embedding vector size (d_model) needs to be divisible by the nubmer of heads (n_heads)"
        self.d_k = d_model // n_heads  # size of each head
        self.d_v = self.d_k  # "Attention Is All You Need" calls this d_v, but it equals d_k

        # the 3 projection layers, each d_model -> d_model
        self.w_q = nn.Linear(d_model, d_model)  # query projection
        self.w_k = nn.Linear(d_model, d_model)  # key projection
        self.w_v = nn.Linear(d_model, d_model)  # value projection

        # projection applied to the concatenated heads at the end
        self.w_concat = nn.Linear(n_heads * self.d_v, d_model)

        # dropout applied to the attention weights
        self.Dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(Q, K, V, mask, dropout):
        """Scaled dot-product attention.

        Declared as a static method because it takes no ``self``; this lets it
        be called both as ``self.attention(...)`` and as
        ``MultiHeadAttention.attention(...)``.

        Args:
            Q: queries; ``forward`` passes (batch, n_heads, query length, d_k).
            K: keys; ``forward`` passes (batch, n_heads, key length, d_k).
            V: values; ``forward`` passes (batch, n_heads, key length, d_k).
            mask: optional tensor broadcastable to the attention logits.
                Positions where ``mask == 0`` are hidden. ``None`` disables
                masking.
            dropout: optional callable applied to the attention weights
                (e.g. an ``nn.Dropout``), or ``None``.

        Returns:
            Tuple ``(output, attn)``: the attended values and the attention
            weights (after softmax and dropout).
        """
        # size of each head is the last dimension of the keys
        d_k = K.shape[-1]
        # Query times the transpose of the last two key dimensions. Scaling by
        # sqrt(d_k) keeps the dot products from growing in magnitude, which
        # would otherwise saturate the softmax and shrink the gradients.
        attn = torch.matmul(Q / (d_k ** 0.5), K.transpose(-2, -1))

        # masked positions get a very negative logit, so softmax maps them to ~0
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = torch.softmax(attn, dim=-1)

        if dropout is not None:
            attn = dropout(attn)

        output = torch.matmul(attn, V)
        # return output and attention weights
        return output, attn

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q: query input of shape (batch, query length, d_model).
            K: key input of shape (batch, key length, d_model).
            V: value input of shape (batch, key length, d_model).
            mask: optional mask; positions where it equals 0 are hidden from
                attention. Defaults to no masking.

        Returns:
            Tensor of shape (batch, query length, d_model). The attention
            weights of this call are stored on ``self.attention_scores``.
        """
        # project each input, then split the result into heads
        query_tensor = self.split(self.w_q(Q))
        key_tensor = self.split(self.w_k(K))
        value_tensor = self.split(self.w_v(V))

        # per-head attention
        output, self.attention_scores = self.attention(
            query_tensor, key_tensor, value_tensor, mask, self.Dropout
        )

        # Concatenate the heads: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        # contiguous() is required before view() because transpose() only
        # changes strides.
        output = output.transpose(1, 2).contiguous().view(
            output.shape[0], -1, self.n_heads * self.d_k
        )

        # final output projection
        return self.w_concat(output)

    def split(self, tensor):
        """Split the last dimension into heads.

        Reshapes (batch, sequence length, d_model) into
        (batch, n_heads, sequence length, d_k).
        """
        batch_size, seq_len = tensor.shape[0], tensor.shape[1]
        return tensor.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)




Each layer in the encoder and decoder has a position-wise feed-forward network. This is just two linear transformations with a ReLU in between. The two linear transformations use different parameters in each layer. The dimensions go from d_model -> d_ff -> d_model at the output.

In [15]:
class PositionFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = max(0, xW1 + b1)W2 + b2.

    Dropout is applied to the hidden activation between the two linear layers.

    Args:
        d_model: size of the input and output features.
        d_ff: size of the hidden layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super(PositionFeedForward, self).__init__()

        # W1 and b1
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)

        # W2 and b2
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply linear -> ReLU -> dropout -> linear to the last dimension of ``x``."""
        # nn.ReLU is a module class; calling nn.ReLU(tensor) would construct a
        # module rather than apply the activation, so use the functional form.
        hidden = nn.functional.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))
        


https://arxiv.org/pdf/1607.06450.pdf
Layer Normalization

Compute the mean and standard deviation over the feature (last) dimension of each position, then use them to normalize that position's features.

In [16]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Computes ``a * (x - mean) / (std + eeps) + b``, where ``mean`` and ``std``
    are taken over the last dimension of ``x``.

    Args:
        d_model: size of the last dimension; length of the scale/shift vectors.
        eeps: small positive constant added to the standard deviation so the
            division stays finite when a row has zero variance.
    """

    def __init__(self, d_model, eeps: float = 1e-12):
        super(LayerNorm, self).__init__()
        # The epsilon must be positive: a negative value can push the
        # denominator to zero or below instead of guarding against it.
        self.eeps = eeps
        self.a = nn.Parameter(torch.ones(d_model))   # learnable scale, starts at 1
        self.b = nn.Parameter(torch.zeros(d_model))  # learnable shift, starts at 0

    def forward(self, x):
        """Normalize ``x`` along its last dimension."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.a * (x - mean) / (std + self.eeps) + self.b


The encoder stacks 6 layers, each with 2 sublayers. The positionally encoded input goes into the multi-head self-attention sublayer, whose output feeds the feed-forward sublayer. The output of each sublayer passes through dropout, is added to that sublayer's input (the residual connection), and the sum is layer-normalized.

In [17]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward sublayers.

    Each sublayer's output goes through dropout, is added to the sublayer's
    input (residual connection) and is then layer-normalized.

    Args:
        d_model: embedding vector size.
        ff: hidden size of the position-wise feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, ff, n_heads, dropout) -> None:
        super(EncoderLayer, self).__init__()
        # MultiHeadAttention requires a dropout probability; omitting it
        # raises a TypeError at construction time.
        self.attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.feed_forward = PositionFeedForward(d_model=d_model, d_ff=ff, dropout=dropout)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        """Apply the layer to ``x``.

        Args:
            x: layer input; used as query, key and value of the self-attention.
            src_mask: mask forwarded to the self-attention (positions equal to
                0 are hidden), or ``None``.

        Returns:
            The layer output after both add & norm steps.
        """
        # self-attention sublayer, keeping the input for the residual
        residual = x
        x = self.attention(Q=x, K=x, V=x, mask=src_mask)

        # add & norm
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        # feed-forward sublayer, again keeping its input for the residual
        residual = x
        x = self.feed_forward(x)

        # add & norm
        x = self.dropout2(x)
        return self.norm2(x + residual)


Pass the input through every encoder layer in turn, then apply layer normalization to the final result.

In [18]:
class Encoder(nn.Module):
    """Stack of ``layers`` encoder layers followed by a final layer normalization.

    Args:
        layers: number of encoder layers to stack.
        d_model: embedding vector size.
        ff: hidden size of each layer's feed-forward network.
        n_heads: number of attention heads per layer.
        dropout: dropout probability used in every layer.
    """

    def __init__(self, layers, d_model, ff, n_heads, dropout) -> None:
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
                                                  ff=ff,
                                                  n_heads=n_heads,
                                                  dropout=dropout)
                                     for _ in range(layers)])
        # LayerNorm needs the feature size to build its scale/shift parameters.
        self.norm = LayerNorm(d_model=d_model)

    def forward(self, x, src_mask):
        """Run ``x`` through every layer with ``src_mask``, then normalize."""
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.norm(x)



Similarly, the decoder is a stack of 6 layers, but each layer has 3 sublayers. The main difference from the encoder is the extra multi-head attention sublayer.

Query comes from output embedding/second positional encoding
Key and value comes from output of Encoder.

Each sublayer has a residual connection that is added back in at the add & norm stage.

In [19]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    Each sublayer's output goes through dropout, is added to the sublayer's
    input (residual connection) and is then layer-normalized.

    Args:
        d_model: embedding vector size.
        ff: hidden size of the position-wise feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, ff, n_heads, dropout):
        super(DecoderLayer, self).__init__()
        # MultiHeadAttention requires a dropout probability; omitting it
        # raises a TypeError at construction time.
        self.mask_attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(dropout)

        self.feed_forward = PositionFeedForward(d_model=d_model, d_ff=ff, dropout=dropout)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, pos_out, encoder_out, encoder_mask, target_mask):
        """Apply the layer.

        Args:
            pos_out: decoder-side input to this layer.
            encoder_out: encoder output; key and value of the second attention.
            encoder_mask: mask for the encoder-decoder attention, or ``None``.
            target_mask: mask for the decoder self-attention, or ``None``.

        Returns:
            The layer output after all three add & norm steps.
        """
        # masked self-attention over the decoder input
        residual = pos_out
        x = self.mask_attention(Q=pos_out, K=pos_out, V=pos_out, mask=target_mask)

        # add & norm
        x = self.dropout1(x)
        x = self.norm1(x + residual)

        # Encoder-decoder attention. The queries are the output of the
        # preceding sublayer (x), not the raw layer input, so this sublayer
        # attends with what the masked self-attention produced. Keys and
        # values come from the encoder.
        residual = x
        x = self.attention(Q=x, K=encoder_out, V=encoder_out, mask=encoder_mask)

        # add & norm
        x = self.dropout2(x)
        x = self.norm2(x + residual)

        # feed-forward sublayer
        residual = x
        x = self.feed_forward(x)

        # add & norm
        x = self.dropout3(x)
        return self.norm3(x + residual)
        

Pass the target through every decoder layer in turn, giving each layer the encoder output and both masks, then apply layer normalization to the result.

In [20]:
class Decoder(nn.Module):
    """Stack of ``layers`` decoder layers followed by a final layer normalization.

    Args:
        layers: number of decoder layers to stack.
        d_model: embedding vector size.
        ff: hidden size of each layer's feed-forward network.
        n_heads: number of attention heads per layer.
        dropout: dropout probability used in every layer.
    """

    def __init__(self, layers, d_model, ff, n_heads, dropout) -> None:
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,
                                                  ff=ff,
                                                  n_heads=n_heads,
                                                  dropout=dropout)
                                     for _ in range(layers)])
        # LayerNorm needs the feature size to build its scale/shift parameters.
        self.norm = LayerNorm(d_model=d_model)

    # target is the decoder side, source is the encoder side
    def forward(self, target, source, target_mask, source_mask):
        """Run ``target`` through every decoder layer, then normalize.

        Args:
            target: decoder-side input.
            source: encoder output attended to by every layer.
            target_mask: mask for the decoder self-attention.
            source_mask: mask for the encoder-decoder attention.

        Returns:
            The normalized output of the last decoder layer.
        """
        for layer in self.layers:
            # DecoderLayer.forward takes (pos_out, encoder_out, encoder_mask,
            # target_mask). Passing by keyword keeps the two masks from being
            # swapped, which a positional call in this method's own argument
            # order would do.
            target = layer(pos_out=target,
                           encoder_out=source,
                           encoder_mask=source_mask,
                           target_mask=target_mask)
        output = self.norm(target)
        return output
