In [None]:
import math

import torch
import torch.nn as nn
from torch.nn.functional import softmax

In [None]:
# try to finish this function on your own
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: (batch_size, num_heads, seq_len_q, d_k)
        key: (batch_size, num_heads, seq_len_k, d_k)
        value: (batch_size, num_heads, seq_len_v, d_v), with seq_len_v == seq_len_k
        mask: Optional tensor broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k).
            Positions where ``mask == 0`` are excluded from attention (their score is
            set to -inf before the softmax), e.g. a causal mask in the decoder.

    Returns:
        Tensor of shape (batch_size, num_heads, seq_len_q, d_v).

    Note:
        Leading dimensions are only required to broadcast, so 3-D inputs
        (batch_size, seq_len, d_k) work as well. If every key position of a query
        row is masked, the softmax over all -inf values yields NaN for that row.
    """
    # get the size of d_k from the last (feature) dimension of the query
    d_k = query.shape[-1]

    # Calculate the attention scores. Be wary of the dimensions of Q and K and of
    # what you need to transpose: swapping the LAST two axes of K gives
    # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k).
    # hint 1: batch_size and num_heads should not change (negative axes leave them untouched)
    # hint 2: (n x m) @ (m x k) -> (n x k); the inner dimensions that "face each other" must match
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Add -inf where the mask is 0 (used by the decoder layer).
    # masked_fill is NOT in-place: its result must be assigned back, otherwise
    # the mask would be silently ignored.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over the key axis only (not batch_size or num_heads), so each
    # query's weights sum to 1.
    attention_weights = softmax(scores, dim=-1)

    # weight the values (V) by the attention weights
    return torch.matmul(attention_weights, value)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, concatenate heads, project back.

    Args:
        d_model: feature dimension of the inputs and of the output.
        num_heads: number of attention heads; must divide d_model evenly.
    """
    #Let me write the initializer just for this class, so you get an idea of how it needs to be done
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads" #think why?

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Note: use integer division //

        # Create the learnable projection matrices
        self.W_q = nn.Linear(d_model, d_model) #think why we are doing from d_model -> d_model
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query: (batch_size, num_heads, seq_len_q, d_k)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_k, d_v)
            mask: optional tensor broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k);
                positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len_q, d_v). A query row whose
            keys are all masked yields NaN (softmax over all -inf).
        """
        d_k = query.shape[-1]

        # Swap the last two axes of K: (..., seq_q, d_k) @ (..., d_k, seq_k) -> (..., seq_q, seq_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # masked_fill is out-of-place; assign the result or the mask is ignored
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # normalise over the key axis only
        attention_weights = softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def _split_heads(self, x):
        """Reshape (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)."""
        # Each tensor uses its OWN sequence length, so keys/values may be longer or
        # shorter than the queries (cross-attention).
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_k, d_model)
            mask: optional tensor broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k);
                positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model).
        """
        # get batch_size and sequence length; inputs are (batch, seq, d_model),
        # so the sequence length lives in dim 1 (dim 2 is d_model).
        batch_size, seq_len = query.shape[0], query.shape[1]

        # 1. Linear projections, then 2. split into heads
        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # 3. Apply attention (this class's own implementation, so it does not
        #    depend on any other notebook cell)
        output = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads: (batch, heads, seq_q, d_k) -> (batch, seq_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # 5. Final projection
        return self.W_o(output)
        