In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class SelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention.

    The input sequence is projected to queries, keys and values by three
    independent linear layers, and every position attends to every position
    of the same sequence. Raw dot products are divided by
    ``sqrt(embedding_dim)`` before the softmax.
    """

    def __init__(self, embedding_dim, dropout=0.1):
        """
        Build the Q/K/V projections and the dropout layer.

        Args:
            embedding_dim (int): Size of the last dimension of the input; each
                projection maps ``embedding_dim -> embedding_dim``.
            dropout (float): Dropout probability applied to the attention
                weights.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        # Divisor applied to the raw query-key dot products.
        self.scale = math.sqrt(embedding_dim)

        # Linear projections for Q, K, V
        self.query = nn.Linear(embedding_dim, embedding_dim)
        self.key = nn.Linear(embedding_dim, embedding_dim)
        self.value = nn.Linear(embedding_dim, embedding_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Compute self-attention over ``x``.

        Args:
            x: Input tensor of shape (batch_size, seq_len, embedding_dim).
            mask: Optional tensor broadcastable to
                (batch_size, seq_len, seq_len). Positions where ``mask == 0``
                are excluded from attention; all other positions are kept.

        Returns:
            output: Attention output, (batch_size, seq_len, embedding_dim).
            attention: Attention weights, (batch_size, seq_len, seq_len).
                These are the weights *after* dropout, so while dropout is
                active individual rows need not sum to 1.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        # torch.bmm only accepts 3-D operands; fail early with a clear message.
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (batch_size, seq_len, embedding_dim), "
                f"got a {x.dim()}-D tensor"
            )

        # Linear projections, each (batch_size, seq_len, embedding_dim)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # (batch_size, seq_len, embedding_dim) @ (batch_size, embedding_dim, seq_len)
        # -> (batch_size, seq_len, seq_len)
        scores = torch.bmm(q, k.transpose(1, 2)) / self.scale

        if mask is not None:
            # Fill with the most negative *finite* value of the score dtype.
            # A fixed -1e9 is not representable in float16, and -inf would
            # turn a fully masked row into NaN after the softmax; a finite
            # minimum keeps such a row uniform instead.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis, then apply dropout to the weights.
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # (batch_size, seq_len, seq_len) @ (batch_size, seq_len, embedding_dim)
        # -> (batch_size, seq_len, embedding_dim)
        output = torch.bmm(attention, v)

        return output, attention

In [None]:
# Example usage
def test_self_attention(seed=None):
    """
    Run SelfAttention on a random batch, check output shapes and print results.

    Args:
        seed (int, optional): If given, passed to ``torch.manual_seed`` before
            the input and the module parameters are created, so the printed
            values are reproducible. Defaults to None (no seeding).

    Returns:
        tuple: ``(output, attention)`` exactly as returned by the module's
        forward pass.
    """
    batch_size = 4
    seq_len = 10
    embedding_dim = 64

    if seed is not None:
        torch.manual_seed(seed)

    # Create random input tensor
    x = torch.randn(batch_size, seq_len, embedding_dim)

    # Initialize the self-attention module. It is left in training mode, so
    # dropout is active: the printed attention weights may contain zeros and
    # their rows need not sum to 1.
    self_attn = SelfAttention(embedding_dim)

    # Forward pass (no mask)
    output1, attention1 = self_attn(x)

    # Shape checks so that this "test" actually verifies something.
    assert output1.shape == (batch_size, seq_len, embedding_dim)
    assert attention1.shape == (batch_size, seq_len, seq_len)

    print(f"Input shape: {x.shape}")
    print("Our Input tensor is --> ")
    print(x)

    print(f"Self-attention output shape: {output1.shape}")
    print("Our Output Context matrix  is --> ")
    print(output1)

    print(f"Self-attention weights shape: {attention1.shape}")
    print("Our Output Attention matrix  is --> ")
    print(attention1)

    return output1, attention1

# Run the demo when this code is executed as the main module.
if __name__ == "__main__":
    test_self_attention()

Input shape: torch.Size([4, 10, 64])
Our Input tensor is --> 
tensor([[[-0.8841,  0.0142, -0.0912,  ..., -0.6959,  0.0986,  0.9691],
         [-0.0103,  0.9237,  1.1994,  ..., -0.0157,  0.2500,  0.7310],
         [ 0.6077,  0.2446,  1.0081,  ..., -0.4376,  1.1936,  0.2615],
         ...,
         [ 1.0289,  0.9542,  0.3406,  ..., -2.0422, -0.6019,  0.2761],
         [ 1.6400,  0.4933, -0.8890,  ...,  1.2761,  0.8840,  1.1745],
         [-0.4358, -0.3111,  1.1548,  ...,  0.5248,  1.1690, -1.0264]],

        [[-0.7454, -1.2645,  0.9789,  ..., -0.9478, -0.0955, -0.0098],
         [-1.0414, -0.0137, -0.7522,  ...,  0.4737,  1.6196,  0.1940],
         [ 1.7506,  2.2319,  0.9934,  ..., -1.7916,  0.7520, -0.6998],
         ...,
         [ 0.5531, -0.2466,  0.2250,  ..., -0.1500, -1.5713, -0.0282],
         [-0.2143, -0.0576, -0.1151,  ..., -1.7637,  0.2469,  0.1862],
         [ 0.4525, -1.0666, -1.6398,  ...,  0.1193, -1.0125, -1.2960]],

        [[-0.7578,  0.5578, -0.7382,  ...,  1.3067, -0