In [1]:
import torch
import torch.nn as nn

# Self-Attention
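$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where:
- $d_k$: dimension of each key (and query) vector. In multi-head attention this is the per-head dimension, `embedding_size / heads`.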

# Multi-Head Self-Attention
<p align="center">
<img src="../images/Multi-Head Self Attention.png" style="width:450px;height:250px;">
</p>




In [41]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into `heads` chunks of size `head_dimension`. The
    same bias-free linear maps W_Q, W_K and W_V are applied to every chunk,
    attention is computed independently per head, and the heads are then
    concatenated and passed through a final linear layer.
    """

    def __init__(self, embedding_size, heads) -> None:
        """
            embedding_size: Dimension of embedding
            heads: Number of splits on the embedding (must divide embedding_size)
        """
        super(MultiHeadAttention, self).__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dimension = embedding_size//heads
        # Example: 256 embedding dimension and 8 heads = Each head will have a dimension of 32

        # A non-divisible value would leave the last head with a different dimension, which we don't want
        assert (self.head_dimension*heads == embedding_size), "Embedding size needs to be divisible by the heads"

        # Linear layers applied to each head-sized chunk to obtain Q, K and V
        self.W_Q = nn.Linear(self.head_dimension, self.head_dimension, bias = False) # W_Q
        self.W_K = nn.Linear(self.head_dimension, self.head_dimension, bias = False) # W_K
        self.W_V = nn.Linear(self.head_dimension, self.head_dimension, bias = False) # W_V
        self.fc_out = nn.Linear(heads*self.head_dimension, embedding_size) # embedding_size --> embedding_size

    def forward(self, query, keys, values, mask=None):
        """
            query:  (batch_size, query_len, embedding_size)
            keys:   (batch_size, key_len, embedding_size)
            values: (batch_size, value_len, embedding_size), with value_len == key_len
            mask:   Optional tensor broadcastable to (batch_size, heads, query_len, key_len).
                    Positions where mask == 0 get (effectively) zero attention weight.

            Returns a tensor of shape (batch_size, query_len, embedding_size).
        """
        batch_size = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # The lengths depend on where the layer is used (encoder self-attention,
        # decoder self-attention or encoder-decoder attention), so they are read
        # from the inputs instead of being fixed.

        # Split embedding into self.heads pieces
        query = query.reshape(batch_size, query_len, self.heads, self.head_dimension)
        keys = keys.reshape(batch_size, key_len, self.heads, self.head_dimension)
        values = values.reshape(batch_size, value_len, self.heads, self.head_dimension)

        # nn.Linear acts on the last dimension, so the shapes stay
        # (batch_size, <len>, heads, head_dim)
        query = self.W_Q(query)
        keys = self.W_K(keys)
        values = self.W_V(values)

        # Dot product of every query with every key, per head
        # n: batch size, q: query length, k: key length, h: heads, d: head dimension
        energy = torch.einsum("nqhd,nkhd->nhqk", query, keys)
        # energy shape: (batch_size, heads, query_len, key_len)

        if mask is not None:
            # A very large negative score turns into a softmax weight of 0
            energy = energy.masked_fill(mask == 0 , value = float("-1e20"))

        # Attention(Q,K,V) = softmax(QK^T/sqrt(d_k))V
        # d_k is the length of the vectors that are dotted together, i.e. the
        # per-head dimension (not the full embedding size).
        attention = torch.softmax(energy/(self.head_dimension**(0.5)), dim = 3) # Normalize across key_len

        # attention shape: (batch_size, heads, query_len, key_len)
        # values shape:    (batch_size, value_len, heads, head_dim), value_len == key_len
        # out shape:       (batch_size, query_len, heads, head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)

        # Concatenate the heads back to the embedding dimension (flatten last two dimensions)
        out = out.reshape(batch_size, query_len, self.heads*self.head_dimension)
        out = self.fc_out(out)
        return out

In [None]:
# Row-wise softmax of a small 2x2 float matrix (each row of the result sums to 1)
torch.softmax(torch.tensor([[0., 2.], [3., 4.]]), dim=1)

### To understand how `masked_fill()` works

In [42]:
# Lower-triangular matrix of ones: 1 marks a position to keep, 0 a position to mask
mask = torch.ones(5, 5).tril()
mask

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [43]:
# Random 5x5 matrix that the mask will be applied to
xy = torch.randn((5, 5))
xy

tensor([[-0.6378,  1.1017, -1.1976, -0.3224, -2.5808],
        [ 1.3692,  1.3641,  0.2511,  0.6175, -1.0132],
        [-1.0865,  0.9378,  0.1768, -0.2069, -0.6674],
        [ 0.4720,  0.0201, -0.3909, -0.0814, -0.7656],
        [-1.0700,  1.9985,  1.7667,  0.6541,  1.3015]])

In [44]:
# Entries where the mask is 0 are replaced by -10; all other entries are kept
xy.masked_fill(mask.eq(0), -10)

tensor([[ -0.6378, -10.0000, -10.0000, -10.0000, -10.0000],
        [  1.3692,   1.3641, -10.0000, -10.0000, -10.0000],
        [ -1.0865,   0.9378,   0.1768, -10.0000, -10.0000],
        [  0.4720,   0.0201,  -0.3909,  -0.0814, -10.0000],
        [ -1.0700,   1.9985,   1.7667,   0.6541,   1.3015]])

# Transformer Encoder Block
<p align="center">
<img src="../images/Transformer Encoder Block.png" style="width:200px;height:300px;">
</p>

This is the Encoder block which we shall implement in the `TransformerEncoderBlock` class

In [45]:
class TransformerEncoderBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer is wrapped as dropout(LayerNorm(sublayer(x) + x)).
    """

    def __init__(self, embedding_size, heads, dropout = 0.1, forward_expansion = 4) -> None:
        """
            embedding_size: Dimension of embedding
            heads: Number of attention heads
            dropout: Dropout probability used after both normalisations
            forward_expansion: Width multiplier of the hidden feed-forward layer
        """
        super().__init__()
        hidden_size = forward_expansion*embedding_size

        self.multiHeadAttention = MultiHeadAttention(embedding_size = embedding_size, heads = heads)
        # BatchNorm: normalises every feature using statistics taken across the batch
        # LayerNorm: normalises every example on its own, across its embedding dimension
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, keys, values, mask = None):
        # Sub-layer 1: attention + residual connection to the query
        attended = self.multiHeadAttention(query, keys, values, mask)
        x_mid = self.dropout(self.norm1(attended + query))

        # Sub-layer 2: feed-forward network + residual connection
        expanded = self.feed_forward(x_mid)
        return self.dropout(self.norm2(expanded + x_mid))

In [46]:
class Encoder(nn.Module):
    """Token + learned positional embeddings followed by a stack of encoder blocks."""

    def __init__(
        self, 
        source_vocab_size, # We will create the embeddings as well
        embedding_size, # We will create the embeddings as well
        num_layers, # Number of Encoder Layers
        heads, 
        device, 
        forward_expansion, 
        dropout, 
        max_length 
        ) -> None:
        """
            source_vocab_size: Number of rows of the word embedding table
            embedding_size: Dimension of embedding
            num_layers: Number of TransformerEncoderBlock layers
            heads: Number of attention heads in every layer
            device: Stored on the module as `self.device`
            forward_expansion: Width multiplier of the feed-forward layers
            dropout: Dropout probability
            max_length: Number of rows of the positional embedding table, i.e.
            the longest sentence the model can handle. For example, if most
            sentences have a length of around 30 - 70 and a couple of them a
            length of 1000, a max_length of 100 keeps the table small; the
            very long sentences then have to be dropped from the data
            (Generally 100 depending on the data)
        """
        super(Encoder, self).__init__()
        self.embedding_size = embedding_size
        self.device = device

        # Trainable Embeddings
        self.word_embedding = nn.Embedding(num_embeddings=source_vocab_size, embedding_dim=embedding_size)
        self.position_embedding = nn.Embedding(max_length, embedding_size) # Trainable Positional Embeddings

        self.layers = nn.ModuleList(
            [
                TransformerEncoderBlock(
                    embedding_size = embedding_size,
                    heads = heads,
                    dropout = dropout,
                    forward_expansion = forward_expansion
                )
                for _ in range(num_layers) # Create num_layers objects of TransformerEncoderBlock
            ]
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        """
            x: Token ids of shape (batch_size, sequence_length)
            mask: Forwarded unchanged to every layer's attention

            Returns a tensor of shape (batch_size, sequence_length, embedding_size).
        """
        batch_size, sequence_length = x.shape # Input words
        # Build the position ids directly on the device of the input: that is where
        # the embedding lookup happens, even if the module was moved after construction.
        positions = torch.arange(0, sequence_length, device=x.device).expand(batch_size, sequence_length)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Self-attention: query, keys and values are all the running representation
        for layer in self.layers:
            out = layer(out, out, out, mask)
        
        return out

# Transformer Decoder Block
<p align="center">
<img src="../images/Transformer Decoder Block.png" style="width:200px;height:400px;">
</p>

This is the Decoder block which we shall implement in the `TransformerDecoderBlock` class

In [47]:
class TransformerDecoderBlock(nn.Module):
    """Masked self-attention on the target followed by an encoder-style block.

    The encoder-style block attends from the decoder state (query) to the
    `key` / `value` tensors passed to `forward`.
    """

    def __init__(
        self, 
        embedding_size,
        heads, 
        forward_expansion, 
        dropout, 
        device
        ) -> None:
        """
            embedding_size: Dimension of embedding
            heads: Number of attention heads
            forward_expansion: Width multiplier of the feed-forward layer
            dropout: Dropout probability
            device: Accepted for signature compatibility; not used by this block
        """
        super().__init__()
        self.attention = MultiHeadAttention(embedding_size, heads)
        self.norm = nn.LayerNorm(embedding_size)
        self.encoder_block = TransformerEncoderBlock(embedding_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, value, key, src_mask, trg_mask):
        """ 
            x: Decoder input; used as query, keys and values of the self-attention
            value, key: Tensors the second attention attends to
            trg mask: Mask applied in the self-attention (triangular mask)
            src mask: Mask applied in the second attention.
            Optional - Generally only used on the padded part of the sentence
        """
        self_attended = self.attention(x, x, x, trg_mask)
        # Residual connection + normalisation; the result becomes the query of the next attention
        query = self.dropout(self.norm(self_attended + x))
        return self.encoder_block(query, key, value, src_mask)
        

# Transformer Decoder
<p align="center">
<img src="../images/Transformer Decoder.png" style="width:150px;height:450px;">
</p>

This is the Decoder which we shall implement in the `Decoder` class

In [75]:
class Decoder(nn.Module):
    """Token + learned positional embeddings, a stack of decoder blocks and a
    final linear layer producing one score per target vocabulary entry."""

    def __init__(
        self,
        target_vocab_size,
        embedding_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length
        ) -> None:
        """
            target_vocab_size: Number of rows of the word embedding table and size of the output layer
            embedding_size: Dimension of embedding
            num_layers: Number of TransformerDecoderBlock layers
            heads: Number of attention heads in every layer
            device: Stored on the module as `self.device` and forwarded to every block
            forward_expansion: Width multiplier of the feed-forward layers
            dropout: Dropout probability
            max_length: Number of rows of the positional embedding table
        """
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(target_vocab_size, embedding_size)
        self.positional_embedding = nn.Embedding(max_length, embedding_size)

        self.layers = nn.ModuleList(
            [
                TransformerDecoderBlock(
                    embedding_size,
                    heads,
                    forward_expansion,
                    dropout,
                    device
                )
                for _ in range(num_layers) # Create num_layers objects of TransformerDecoderBlock
            ]
        )

        # Will get an input of shape: (batch_size, query_len, embedding_size)
        self.fc_out = nn.Linear(embedding_size, target_vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_out, src_mask, trg_mask):
        """
            x: Target token ids of shape (batch_size, sequence_length)
            enc_out: Encoder output, used as key and value in every layer
            src_mask: Mask for the attention over enc_out
            trg_mask: Mask for the self-attention over the target

            Returns a tensor of shape (batch_size, sequence_length, target_vocab_size).
        """
        batch_size, sequence_length = x.shape # Input words
        # Position ids 0..sequence_length-1 for every example, created on the input's device
        positions = torch.arange(0, sequence_length, device=x.device).expand(batch_size, sequence_length)
        # The positional table is indexed by position, NOT by token id
        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        for layer in self.layers:
            out = layer(out, value = enc_out, key = enc_out, src_mask = src_mask, trg_mask = trg_mask)
        
        out = self.fc_out(out)
        return out

# Transformer

In [76]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batches of token ids."""

    def __init__(
        self,
        source_vocab_size,
        target_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embedding_size = 256,
        num_layers = 6,
        forward_expansion = 4,
        heads = 8,
        dropout = 0.1,
        device = "cpu",
        max_length = 100
        ) -> None:
        """
            source_vocab_size / target_vocab_size: Vocabulary sizes of encoder and decoder
            src_pad_idx: Token id treated as padding in the source (masked out)
            trg_pad_idx: Padding token id of the target (stored, not used by the masks)
            embedding_size: Dimension of embedding
            num_layers: Number of layers in the encoder and in the decoder
            forward_expansion: Width multiplier of the feed-forward layers
            heads: Number of attention heads
            dropout: Dropout probability
            device: Device the masks are moved to
            max_length: Size of the positional embedding tables
        """
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            source_vocab_size,
            embedding_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            target_vocab_size,
            embedding_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    # Functions to create the masks

    def make_src_mask(self, src):
        """Padding mask: True for real tokens, False where src == src_pad_idx."""
        # src shape: (batch_size, src_len)
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2) # Adding 2 dimensions
        # src_mask shape: (batch_size, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        """Causal mask: lower triangular matrix with 1s on and below the diagonal and 0s above."""
        # torch.tril() keeps the lower triangle, so row i has 1s for columns 0..i only
        batch_size, trg_length = trg.shape
        trg_mask = torch.tril(torch.ones((trg_length, trg_length))).expand(batch_size, 1, trg_length, trg_length)
        # Expanded to obtain one mask per example: (batch_size, 1, trg_len, trg_len)
        return trg_mask.to(self.device)
    
    def forward(self, source, target):
        """
            source: Source token ids of shape (batch_size, src_len)
            target: Target token ids of shape (batch_size, trg_len)

            Returns the decoder output for `target` given the encoded `source`.
        """
        src_mask = self.make_src_mask(source)
        # The decoder self-attention needs the causal (triangular) mask, not a padding mask
        trg_mask = self.make_trg_mask(target)
        enc_src = self.encoder(x = source, mask = src_mask)
        dec_out = self.decoder(x = target, enc_out = enc_src, src_mask = src_mask, trg_mask = trg_mask)
        return dec_out

### To understand `torch.tril()` and `torch.triu()`

In [50]:
# A matrix of ones, and the same matrix with everything above the diagonal zeroed
print('Only ones - \n', torch.ones(5, 5))
print('\nLower Triangular Matrix - \n', torch.ones(5, 5).tril())

Only ones - 
 tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

Lower Triangular Matrix - 
 tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


In [51]:
# tril() zeroes the entries above the diagonal, triu() the entries below it
a = torch.randn(5, 5)
print('Another Example with random numbers- \n', a)
print('\nLower Triangular version - \n', a.tril())
print('\nUpper Triangular version - \n', a.triu())

Another Example with random numbers- 
 tensor([[-0.0524, -1.5357, -0.9137,  0.7272, -0.0664],
        [-0.1379,  1.5173, -0.9358, -0.1186, -1.7448],
        [-1.7164, -0.0762, -0.1048, -1.0726,  0.1657],
        [-0.6807,  0.2704, -1.0409,  0.0774,  0.4895],
        [-0.2081,  0.7440, -0.9311,  0.0226,  0.6382]])

Lower Triangular version - 
 tensor([[-0.0524,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1379,  1.5173,  0.0000,  0.0000,  0.0000],
        [-1.7164, -0.0762, -0.1048,  0.0000,  0.0000],
        [-0.6807,  0.2704, -1.0409,  0.0774,  0.0000],
        [-0.2081,  0.7440, -0.9311,  0.0226,  0.6382]])

Upper Triangular version - 
 tensor([[-0.0524, -1.5357, -0.9137,  0.7272, -0.0664],
        [ 0.0000,  1.5173, -0.9358, -0.1186, -1.7448],
        [ 0.0000,  0.0000, -0.1048, -1.0726,  0.1657],
        [ 0.0000,  0.0000,  0.0000,  0.0774,  0.4895],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.6382]])


## Example

In [83]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Toy batch of source token ids: 2 sentences of 9 tokens
x = torch.tensor(
    [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]],
    device=device,
)
# Toy batch of target token ids: 2 sentences of 8 tokens
trg = torch.tensor(
    [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]],
    device=device,
)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    trg_pad_idx,
    device=device,
).to(device)

# The decoder is fed the target without its last token
out = model(x, trg[:, :-1])
print(out.shape)

cpu
torch.Size([2, 7, 10])
