## Original transformer architecture as per the "Attention is all you need" paper

### Necessary imports

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

### Multi-head self attention

In [17]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention block.

    The input is projected to queries, keys and values with three
    (d_model x d_model) linear layers, split into `num_heads` heads of
    size d_head = d_model // num_heads, attended per head in parallel,
    re-concatenated and passed through a final (d_model x d_model) output
    projection.
    """

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model : dimension of the embedding vectors
            num_heads : number of attention heads; must divide d_model evenly

        Return: N/A

        Raises:
            ValueError : if d_model is not divisible by num_heads

        Explanation:
            Why nn.Linear(d_model,d_model)? If d_model=512, that means a giant 512x512 weight matrix.
            If d_model=512, num_heads = 8, then d_head = 64. Shouldn't it be nn.Linear(d_model,d_head)?
            Then for an input I = 1x512, and weight matrix W = 64x512, doing I x W.transpose will give 1x64

            Reason is, I=1x512 with W=512x512 will give 1x512. This will be split into 8 parts, 1 per head.
            That would mean the vector fed to each head would be 1x64.
            So the computation is done in one shot for efficiency, instead of multiplying with 8 different
            matrices of size 64x512.
        """
        super(MultiHeadAttention, self).__init__()

        # Fail early with a clear message; otherwise the mismatch only shows up
        # later as an opaque reshape error inside split_heads().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads  # Dimension of the Key, Query & Value vector passed to each head

        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Args:
            Q,K,V : Query, Key & Value matrices. Dimension is (batch_size x num_heads x seq_length x d_head)
            mask : optional tensor broadcastable to the attention scores
                   (batch_size x num_heads x seq_length x seq_length). Positions where
                   mask == 0 are excluded from attention.

        Return:
            Z : Context vectors, from multiple self-attention blocks
                Dimension is (batch_size x num_heads x seq_length x d_head)
        """

        # Compute attention scores: Q x K^T, scaled by sqrt(d_head)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # Apply mask if available. A large negative score becomes ~0 probability
        # after the softmax, so masked positions contribute nothing to Z.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Compute softmax, row-wise, i.e. for a given row all column values are fed to the softmax
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Compute output as summation of weighted V vectors
        Z = torch.matmul(attn_probs, V)
        return Z

    def split_heads(self, x):
        """
        Args:
            x : A tensor representing either Q, K or V, obtained after applying the corresponding W_q, W_k or
                W_v to the input batch. Size is (batch_size x seq_length x d_model)
        Return: A reordered tensor of size (batch_size x num_heads x seq_length x d_head). The input is first
                viewed as (batch_size x seq_length x num_heads x d_head) and then axes 1 and 2 are swapped.
                num_heads x d_head = d_model
        """
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_head).transpose(1, 2)

    def combine_heads(self, x):
        """
        Args:
            x : A tensor representing the computed context vectors
                Dimensions are (batch_size x num_heads x seq_length x d_head)
        Return: This function does the reverse operation of split_heads. Therefore, it takes the input and
                creates (batch_size x seq_length x d_model)
        """
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view; contiguous() is required before view()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q,K,V : For self-attention all three parameters receive the same input batch.
                    Dimension of input batch is (batch_size x seq_length x d_model)
            mask : optional mask forwarded to scaled_dot_product_attention; positions where
                   mask == 0 are not attended to (e.g. future tokens in the decoder)
        Return:
            output : A tensor representing the output of the multi-head self-attention block
                    Dimension is (batch_size x seq_length x d_model)


        Explanation:
            Input is a batch of samples.
            Assume batch_size = 32. Each sample is a sequence of tokens. Assume seq_length = 16. Each token
            will be converted to an embedding. Assume embedding size = 512.
            Then, after the embedding is generated for a batch of 32 samples and the positional embeddings are
            added, each batch has dimensions (32x16x512)

            ========== Compute the Q, K & V vectors ============

            The weight matrices W_q, W_k & W_v are all 512x512 (see reason in __init__ method).
            Q, K & V are computed by applying the corresponding weight matrices to a copy of the input batch.
            The resultant matrix's dimensions remain the same as the input batch at 32x16x512.

            Each matrix is then resized to (32x16x8x64). This means that for each of the 16 tokens, there are
            8 vectors, each of them 64-D, which will be fed in parallel to the 8 attention heads.
            For efficient computation, each matrix is rearranged to (32x8x16x64). Now, for each of the 8 heads,
            there are 16 vectors (corresponding to 16 tokens in each sequence) of 64-D each.
            In other words, each of the 8 attention heads will receive a batch of 32 tensors, where each
            tensor will consist of 16 vectors, each of dimension 64-D.

            ========== Compute multi-head self-attention context vectors =========
            Attention scores are computed by multiplying Q = (32x8x16x64) by the transpose of K = (32x8x64x16).
            Note that prior to the transpose K was (32x8x16x64). Resultant matrix is 32x8x16x16
            This score is rescaled by dividing with sqrt(d_head), where d_head = 64. Next softmax is applied.
            Dimensions don't change during the rescaling and softmax
            Next the matrix is multiplied by the value matrix. So (32x8x16x16) x (32x8x16x64) -> (32x8x16x64).
            This multiplication performs 2 actions to compute the final context vector for each token:
            a) it multiplies the V vectors with the computed attention probabilities, and
            b) sums the vectors.
            This is done across all 8 heads in parallel, generating 8 context vectors (64-D) per token

            ========== Compute combined output ===========
            The (32x8x16x64) is reversed back to (32x16x8x64) which is (batch_size x seq_length x num_heads
            x d_head). It is then rearranged so that for each token, the 8 vectors of 64-D dimension are
            concatenated to form a 512 vector, resulting in (32x16x512). This is then passed through a linear
            layer with weight matrix 512x512. The final tensor is (32x16x512)

        """

        # Compute the Q, K & V vectors
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Compute multi-head self-attention context vectors
        attn_ctxt_vecs = self.scaled_dot_product_attention(Q, K, V, mask)

        # Compute combined output
        output = self.W_o(self.combine_heads(attn_ctxt_vecs))

        return output
           