In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class CausalSelfAttention(nn.Module):
    """
    Single-head causal self-attention.

    Every position may attend to itself and to earlier positions only; scores
    for later ("future") positions are masked out before the softmax. This is
    the attention pattern used by decoder-only architectures such as GPT.

    Queries, keys and values are each produced by a learned linear projection
    of the input, and the raw scores are divided by sqrt(embedding_dim).
    """

    def __init__(self, embedding_dim, dropout=0.1):
        """
        Initialize the causal self-attention module.

        Args:
            embedding_dim (int): The dimension of the input embeddings
            dropout (float): Dropout probability applied to the attention weights
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        # Scores are divided by this value before the softmax.
        self.scale = math.sqrt(embedding_dim)

        # Linear projections for Q, K, V
        self.query = nn.Linear(embedding_dim, embedding_dim)
        self.key = nn.Linear(embedding_dim, embedding_dim)
        self.value = nn.Linear(embedding_dim, embedding_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass for causal self-attention

        Args:
            x: Input tensor (batch_size, seq_len, embedding_dim)

        Returns:
            output: Causal self-attention output (batch_size, seq_len, embedding_dim)
            attention: Attention weights (batch_size, seq_len, seq_len). These are
                the weights *after* dropout, so in training mode individual
                entries may be zeroed/rescaled and rows need not sum to 1.

        Raises:
            ValueError: If x is not a 3-dimensional tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected input of shape (batch_size, seq_len, embedding_dim), "
                f"got tensor with shape {tuple(x.shape)}"
            )
        seq_len = x.size(1)

        # Linear projections
        q = self.query(x)  # (batch_size, seq_len, embedding_dim)
        k = self.key(x)    # (batch_size, seq_len, embedding_dim)
        v = self.value(x)  # (batch_size, seq_len, embedding_dim)

        # Compute attention scores
        # (batch_size, seq_len, embedding_dim) @ (batch_size, embedding_dim, seq_len)
        # -> (batch_size, seq_len, seq_len)
        scores = torch.bmm(q, k.transpose(1, 2)) / self.scale

        # Boolean mask that is True where the key index is greater than the query
        # index (i.e. strictly above the diagonal). It is built directly on the
        # input's device to avoid a host-to-device copy on every call.
        positions = torch.arange(seq_len, device=x.device)
        future_mask = positions.unsqueeze(0) > positions.unsqueeze(1)

        # Fill with -inf rather than a large finite constant: a value such as -1e9
        # is not representable in float16. The diagonal is never masked, so every
        # row keeps at least one finite score and the softmax cannot produce NaN.
        scores = scores.masked_fill(future_mask, float("-inf"))

        # Apply softmax to get attention weights
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention weights to values
        # (batch_size, seq_len, seq_len) @ (batch_size, seq_len, embedding_dim)
        # -> (batch_size, seq_len, embedding_dim)
        output = torch.bmm(attention, v)

        return output, attention

In [None]:
# Example usage
def test_causal_attention(seed=None):
    """
    Run CausalSelfAttention on a small random input, print the tensors involved
    and verify the causal structure of the attention weights.

    The module is switched to evaluation mode so that dropout is disabled and
    the printed attention matrix is the plain softmax output (rows sum to 1).

    Args:
        seed (int, optional): If given, passed to torch.manual_seed so that the
            random input and the module's initial weights are reproducible.

    Returns:
        output: Attention output (batch_size, seq_len, embedding_dim)
        attention: Attention weights (batch_size, seq_len, seq_len)

    Raises:
        AssertionError: If any attention weight lies above the diagonal, or if
            a row of attention weights does not sum to 1.
    """
    if seed is not None:
        torch.manual_seed(seed)

    batch_size = 1        #4
    seq_len = 4           #10
    embedding_dim = 8     #64

    # Create random input tensor
    x = torch.randn(batch_size, seq_len, embedding_dim)

    # Initialize the causal-attention module; eval() turns dropout off so the
    # weights shown below are not randomly zeroed or rescaled.
    causal_attn = CausalSelfAttention(embedding_dim)
    causal_attn.eval()

    # Forward pass
    output, attention = causal_attn(x)

    print(f"Input shape: {x.shape}")
    print("Our Input tensor is --> ")
    print(x)

    print(f"Causal-attention output shape: {output.shape}")
    print("Our Output Context matrix  is --> ")
    print(output)

    print(f"Causal-attention weights shape: {attention.shape}")
    print("Our Output Attention matrix  is --> ")
    print(attention)

    # Verify that causal attention has the right pattern (lower triangular)
    print(f"Is causal mask working? {torch.all(torch.triu(attention[0], diagonal=1) == 0)}")

    # Fail loudly instead of relying on someone reading the printed flag.
    assert torch.all(torch.triu(attention, diagonal=1) == 0), \
        "attention weights above the diagonal must be zero"
    assert torch.allclose(attention.sum(dim=-1), torch.ones(batch_size, seq_len), atol=1e-5), \
        "each row of attention weights must sum to 1"

    return output, attention

if __name__ == "__main__":
    test_causal_attention()

Input shape: torch.Size([1, 4, 8])
Our Input tensor is --> 
tensor([[[ 1.3303, -1.1235,  1.1652, -0.0349,  0.1052, -1.6020,  0.6992,
           1.3815],
         [ 1.3089,  1.1137, -1.1362, -0.9922,  0.0993,  0.5333, -1.9262,
           1.2088],
         [-0.3219, -0.3680,  0.4476,  1.2465, -0.4465,  1.4749,  0.4591,
           1.4252],
         [-0.5158, -0.5374,  0.1797,  1.6236, -0.9445, -0.8403,  0.2905,
           1.3741]]])
Causal-attention output shape: torch.Size([1, 4, 8])
Our Output Context matrix  is --> 
tensor([[[ 0.2434,  0.3784, -0.5413, -0.1418,  1.0619, -0.0691,  0.6464,
          -0.6711],
         [ 0.0873,  0.1358, -0.1943, -0.0509,  0.3811, -0.0248,  0.2320,
          -0.2408],
         [-0.3893, -0.2056, -0.4329,  0.4260,  0.0843, -0.1285,  0.2085,
          -0.1206],
         [-0.0736, -0.3307, -0.5131,  0.8113,  0.5907,  0.1014,  0.6621,
           0.2256]]], grad_fn=<BmmBackward0>)
Causal-attention weights shape: torch.Size([1, 4, 4])
Our Output Attention matri