In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F



In [12]:


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Computes ``softmax(q @ k^T / sqrt(depth) + mask * -1e9) @ v`` over the last
    axis (the key axis).

    Args:
      q: query tensor of shape (..., seq_len_q, depth).
      k: key tensor of shape (..., seq_len_k, depth).
      v: value tensor of shape (..., seq_len_k, depth_v).
      mask: optional tensor broadcastable (in either direction) with the
        logits of shape (..., seq_len_q, seq_len_k). Entries equal to 1 add
        -1e9 to the matching logit, so that key position gets (effectively)
        zero attention weight; entries equal to 0 leave the logit unchanged.
        Defaults to None (no masking).

    Returns:
      (output, attention_weights):
        output has shape (..., seq_len_q, depth_v);
        attention_weights has shape (..., seq_len_q, seq_len_k) and sums to 1
        over the last axis. When a mask is given, the leading dims are the
        broadcast of the q/k leading dims and the mask's leading dims.
    """
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)

    # Scale by sqrt(depth) so the logits' magnitude does not grow with depth.
    depth = k.shape[-1]
    logits = matmul_qk / math.sqrt(depth)

    if mask is not None:
        # Out-of-place add: an in-place `+=` cannot enlarge `logits`, so a mask
        # with extra leading dims (e.g. per-batch mask, unbatched q/k) would fail.
        # Casting to logits' dtype keeps the output dtype identical to the
        # previous in-place behaviour (e.g. a float64 or bool mask on float32 q/k).
        logits = logits + (mask * -1e9).to(logits.dtype)

    # Softmax over the key axis so each query's weights sum to 1.
    attention_weights = F.softmax(logits, dim=-1)  # (..., seq_len_q, seq_len_k)

    output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights


In [14]:

# Multi-head attention layer built on scaled_dot_product_attention from the previous cell.
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, then merge the heads.

    Args:
      d_model: model width; must be divisible by ``num_heads``.
      num_heads: number of heads; each head has width ``d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // num_heads

        # Created in q, k, v order, then the output projection: this fixes both
        # the parameter registration order and the RNG draws at initialisation.
        self.wq, self.wk, self.wv = (nn.Linear(d_model, d_model) for _ in range(3))
        self.final_linear = nn.Linear(d_model, d_model)

    def split_into_heads(self, x, batch_size):
        """Split the last axis into heads and move the head axis forward.

        x is expected to be (batch_size, seq_len, d_model); the result is
        (batch_size, num_heads, seq_len, depth).
        """
        per_head = x.view(batch_size, -1, self.num_heads, self.depth)
        return per_head.transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
          q, k, v: inputs projected by wq / wk / wv and split into heads;
            presumably (batch, seq_len, d_model) each — TODO confirm with callers.
          mask: optional mask forwarded unchanged to
            scaled_dot_product_attention; it must broadcast against the
            per-head logits (batch, num_heads, seq_len_q, seq_len_k).

        Returns:
          (output, attention_weights): output is (batch, seq_len_q, d_model);
          attention_weights is (batch, num_heads, seq_len_q, seq_len_k).
        """
        n_batch = q.size(0)

        q_heads, k_heads, v_heads = (
            self.split_into_heads(proj(inp), n_batch)
            for proj, inp in ((self.wq, q), (self.wk, k), (self.wv, v))
        )

        per_head_out, attention_weights = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask
        )

        # (batch, heads, seq_q, depth) -> (batch, seq_q, heads, depth) -> (batch, seq_q, d_model)
        merged = per_head_out.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)

        return self.final_linear(merged), attention_weights

# Initialize parameters
SEED = 0  # fixes layer initialisation and the random input for Restart & Run All
d_model = 50
num_heads = 5
batch_size = 1
seq_len = 4

# Seed before building the layer so both its initial weights and `x` are reproducible.
torch.manual_seed(SEED)

# Create the MultiHeadAttention layer
mha = MultiHeadAttention(d_model, num_heads)

# Create a random input tensor [batch_size, seq_len, d_model]
x = torch.randn(batch_size, seq_len, d_model)

# Call the MultiHeadAttention layer (self-attention: q, k and v are the same tensor)
output, attn_weights = mha(x, x, x)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)

# Check the expected shapes against the config above instead of hard-coded comments.
assert output.shape == (batch_size, seq_len, d_model)
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)
# Each query's attention weights form a probability distribution over keys.
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, num_heads, seq_len))


tensor([[[-0.2434,  0.2149, -1.3231,  0.4357,  0.5050, -1.8991,  1.5277,
           0.1787, -1.6589,  1.4410, -0.8033, -0.8834,  1.6311, -1.0334,
          -1.3565,  0.8466, -0.1715, -0.2217, -0.3507,  0.1953, -0.4026,
           0.6221,  0.0838,  1.0236,  0.1254, -0.1957, -0.4559,  0.7202,
           0.2628, -0.5253, -0.9144, -0.7615, -1.0396, -1.0524,  0.8715,
           0.4677, -0.7454, -0.0298, -0.2609,  0.6635, -0.8599, -0.9492,
          -0.9492,  1.7086,  1.8027,  0.7773,  1.8323, -0.6181,  0.5029,
           0.0568],
         [-1.2629,  0.4335, -0.9180,  1.1268, -0.9282, -0.4077, -0.1635,
          -0.0834, -0.7019, -0.9074,  0.2541, -1.3722, -0.0398, -1.4604,
           0.2647, -0.7912, -0.2016,  2.2284, -0.5283, -1.0036, -1.4041,
          -0.0665,  0.8386, -0.6801, -1.0920, -0.6169, -0.0969, -0.0285,
           0.4917,  1.4354,  0.1649,  0.0543, -0.3693,  0.2572, -1.6187,
          -0.5503, -0.3965, -0.8212, -1.1357,  1.7308,  0.3621,  0.4882,
          -0.0255, -0.5075,  0.