In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are each produced by a learned linear projection
    of the same input, and the output is the attention-weighted sum of the
    values: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        embed_dim: Size of the last (feature) dimension of the input. The
            query/key/value projections all map ``embed_dim -> embed_dim``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.W_query = nn.Linear(embed_dim, embed_dim)
        self.W_key = nn.Linear(embed_dim, embed_dim)
        self.W_value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Attend over the second-to-last (sequence) dimension of ``x``.

        Args:
            x: Tensor of shape ``[..., T, C]`` with ``C == embed_dim``. Any
                number of leading batch dimensions is accepted, including
                none, e.g. ``[T, C]``, ``[B, T, C]`` or ``[B, H, T, C]``.
            mask: Optional tensor broadcastable to ``[..., T, T]``. True
                (or nonzero) entries mark (query, key) pairs that may be
                attended to; the others are excluded from the softmax. A
                query row with no allowed key yields NaN weights.

        Returns:
            Tuple ``(out, attn_weights)``. ``out`` has shape ``[..., T, C]``.
            ``attn_weights`` has shape ``[..., T, T]``, and each row sums to 1.
        """
        # Apply linear transformations to get queries, keys, and values.
        q = self.W_query(x)  # [..., T, C]
        k = self.W_key(x)    # [..., T, C]
        v = self.W_value(x)  # [..., T, C]

        # Unlike bmm, matmul broadcasts over any leading batch dims, so the
        # module is not restricted to exactly-3D input.
        scores = torch.matmul(q, k.transpose(-2, -1))  # [..., T, T]
        # Scale by sqrt(d_k) so dot products don't grow with the feature size
        # and saturate the softmax.
        scores = scores / (k.size(-1) ** 0.5)

        if mask is not None:
            # -inf becomes exactly 0 after softmax. logical_not also accepts
            # non-bool masks, where nonzero means "allowed".
            scores = scores.masked_fill(mask.logical_not(), float("-inf"))

        # Normalize over the key dimension to get attention weights.
        attn_weights = F.softmax(scores, dim=-1)  # [..., T, T]

        # Weighted sum of values.
        out = torch.matmul(attn_weights, v)       # [..., T, C]

        return out, attn_weights

# Smoke test: a batch of 2 sequences, each 5 tokens with embedding size 10.
x = torch.randn(2, 5, 10)
attention = BasicAttention(10)
out, attn_weights = attention(x)
# Assert rather than only print, so a regression fails loudly instead of
# relying on someone reading the printed shape.
assert out.shape == (2, 5, 10), out.shape
assert attn_weights.shape == (2, 5, 5), attn_weights.shape
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(2, 5), atol=1e-6)
print(out.size())  # torch.Size([2, 5, 10])

torch.Size([2, 5, 10])
