In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F



In [12]:


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    The scores are ``q @ k^T / sqrt(depth)``, where ``depth`` is the size of the
    last axis of ``k``. A softmax over the key axis turns the scores into
    weights, which are then used to take a weighted sum of ``v``.

    Args:
        q: query tensor, shape (..., seq_len_q, depth).
        k: key tensor, shape (..., seq_len_k, depth).
        v: value tensor, shape (..., seq_len_k, depth_v).
        mask: optional tensor that broadcasts against
            (..., seq_len_q, seq_len_k). It is multiplied by -1e9 and added to
            the scores, so entries equal to 1 are suppressed and entries equal
            to 0 are left untouched. The mask may carry more leading
            dimensions than the scores; the result is then broadcast up to the
            mask's leading shape. Defaults to None (no masking).

    Returns:
        A tuple ``(output, attention_weights)``:
        output has shape (..., seq_len_q, depth_v) and attention_weights has
        shape (..., seq_len_q, seq_len_k), with each row summing to 1.
    """
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)

    # Scale by sqrt(depth) so the score magnitude does not grow with the
    # feature dimension.
    depth = k.shape[-1]
    logits = matmul_qk / math.sqrt(depth)

    if mask is not None:
        # Out-of-place add: an in-place `+=` cannot change the shape of
        # `logits`, so it raises when the mask broadcasts the scores up to a
        # larger batch shape. A fresh tensor handles every broadcastable mask.
        logits = logits + mask * -1e9

    # Softmax is applied to the last axis (seq_len_k) so that the weights of
    # each query position add up to 1.
    attention_weights = F.softmax(logits, dim=-1)  # (..., seq_len_q, seq_len_k)

    output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights


In [13]:

# Multi-head attention layer, followed by a small self-attention smoke test.
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    The inputs are projected by three linear layers, split into ``num_heads``
    slices of width ``depth = d_model // num_heads``, attended per head, merged
    back to width ``d_model`` and passed through a final linear layer.
    """

    def __init__(self, d_model, num_heads):
        """Create the projection layers.

        Args:
            d_model: feature width of the inputs and of the output.
            num_heads: number of attention heads; must divide ``d_model``.

        Raises:
            AssertionError: if ``d_model`` is not a multiple of ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Width of the feature slice handled by each head.
        self.depth = d_model // num_heads

        # Query / key / value projections, then the output projection.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.final_linear = nn.Linear(d_model, d_model)

    def split_into_heads(self, x, batch_size):
        """Reshape ``x`` to (batch_size, num_heads, seq_len, depth).

        The last axis of ``x`` is cut into ``num_heads`` chunks of ``depth``
        features, and the head axis is moved in front of the sequence axis.
        """
        per_head = x.view(batch_size, -1, self.num_heads, self.depth)
        return per_head.transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: tensors whose first axis is the batch and whose last axis
                has ``d_model`` features.
            mask: optional mask handed unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            A tuple ``(output, attention_weights)``: the output after the
            final linear layer, with last axis ``d_model``, and the per-head
            attention weights returned by ``scaled_dot_product_attention``.
        """
        batch_size = q.size(0)

        # Project each input, then expose the heads as a separate axis.
        q_heads, k_heads, v_heads = (
            self.split_into_heads(projection(tensor), batch_size)
            for projection, tensor in ((self.wq, q), (self.wk, k), (self.wv, v))
        )

        context, attention_weights = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask
        )

        # Move the head axis back next to the features and merge the two;
        # contiguous() is needed because view() cannot merge transposed axes.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.d_model)

        return self.final_linear(merged), attention_weights

# Demo configuration for a small self-attention smoke test.
d_model = 50
num_heads = 5
batch_size = 1
seq_len = 4

# Seed the global RNG so that the layer's initial weights and the random input
# are identical on every Restart & Run All.
torch.manual_seed(0)

# Create the MultiHeadAttention layer
mha = MultiHeadAttention(d_model, num_heads)

# Create a random input tensor [batch_size, seq_len, d_model]
x = torch.randn(batch_size, seq_len, d_model)

print(x)

# Call the MultiHeadAttention layer; passing the same tensor as q, k and v
# makes this self-attention.
output, attn_weights = mha(x, x, x)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)

# Check the expected shapes instead of only stating them in comments.
assert output.shape == (batch_size, seq_len, d_model)
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)


Output shape: torch.Size([1, 4, 50])
Attention weights shape: torch.Size([1, 5, 4, 4])
