In [1]:
import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.module() and nn.Linear()
import torch.nn.functional as F # This gives us the softmax()

In [None]:
class Transformer_MaskSelfAttention(nn.Module):
    """
    Masked (causal) scaled dot-product self-attention for a single token matrix.

    Queries, keys and values come from three bias-free linear projections of the
    input. An optional boolean mask hides selected query/key pairs (e.g. future
    tokens) before the softmax, so those keys receive zero attention weight.
    """

    def __init__(self, dimentiona_model=2, row_dimension=0, column_dimension=1):
        """
        Create the query, key and value projection layers.

        Args:
            dimentiona_model (int): Size of each token embedding; every projection
                maps d_model -> d_model (default=2). The misspelled name is kept so
                existing callers passing it by keyword keep working.
            row_dimension (int): Tensor dimension that indexes tokens (default=0).
                Used when transposing the key matrix.
            column_dimension (int): Tensor dimension that indexes features (default=1).
                Used when transposing keys, for the scaling size, and as the
                softmax dimension.
        """
        super().__init__()

        # Bias-free projections producing Queries (Q), Keys (K) and Values (V)
        self.weight_query = nn.Linear(in_features=dimentiona_model, out_features=dimentiona_model, bias=False)
        self.weight_key = nn.Linear(in_features=dimentiona_model, out_features=dimentiona_model, bias=False)
        self.weight_value = nn.Linear(in_features=dimentiona_model, out_features=dimentiona_model, bias=False)

        # Which tensor dimensions hold tokens (rows) and features (columns)
        self.row_dimension = row_dimension
        self.column_dimension = column_dimension

    def forward(self, token_encodings, mask=None):
        """
        Compute masked self-attention for a token encoding matrix.

        Args:
            token_encodings (Tensor): Token embeddings; with the default dimensions
                this is (sequence_length, d_model).
            mask (Tensor, optional): Boolean tensor broadcastable to the
                (sequence_length, sequence_length) score matrix. True marks a
                query/key pair to hide. It must be boolean: ``masked_fill`` rejects
                other dtypes.

        Returns:
            Tensor: Attention-weighted sums of the value vectors, one row per token.
        """
        q = self.weight_query(token_encodings)
        k = self.weight_key(token_encodings)
        v = self.weight_value(token_encodings)

        # Raw similarity: every query dotted with every key (Q @ K^T)
        similarity_scores = torch.matmul(q, k.transpose(dim0=self.row_dimension, dim1=self.column_dimension))

        # Scale by sqrt(key dimension) so the softmax does not saturate as d_model grows
        scaled_similarity_scores = similarity_scores / torch.tensor(k.size(self.column_dimension)**0.5, dtype=torch.float32)

        if mask is not None:
            # Use the most negative finite value of the score dtype. exp() of it
            # underflows to exactly 0 after the softmax, just like a large constant
            # would. Unlike a fixed -1e9, it also fits in float16, where -1e9
            # overflows and makes masked_fill raise.
            fill_value = torch.finfo(scaled_similarity_scores.dtype).min
            scaled_similarity_scores = scaled_similarity_scores.masked_fill(mask, value=fill_value)

        # Attention weights: each row (one query token) sums to 1
        attention_percents = F.softmax(scaled_similarity_scores, dim=self.column_dimension)

        # Weighted sum of value vectors for every token
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

In [None]:
# Seed the global RNG first so the randomly initialised Q/K/V projection
# weights of the layer created below are the same on every run.
torch.manual_seed(42)

# Toy input: 3 tokens (rows), each embedded with 2 features (columns),
# i.e. sequence_length=3 and d_model=2.
encoding_matrix = torch.tensor(
    [
        [1.16, 0.23],   # token 1
        [0.57, 1.36],   # token 2
        [4.41, -2.16],  # token 3
    ]
)

# Masked self-attention layer operating on (tokens, features) matrices
tmsat = Transformer_MaskSelfAttention(
    dimentiona_model=2,
    row_dimension=0,
    column_dimension=1,
)

# Build the causal mask from the actual number of tokens, so it stays in sync
# with `encoding_matrix` if tokens are added or removed.
sequence_length = encoding_matrix.size(0)

# torch.tril keeps the ones ON and below the diagonal: those are the allowed
# (query, key) pairs, i.e. each token may look at itself and earlier tokens.
allowed_positions = torch.tril(torch.ones(sequence_length, sequence_length))

# Invert into the boolean form `masked_fill` needs:
# True = masked (strictly above the diagonal, a future token),
# False = allowed (on or below the diagonal).
mask = allowed_positions == 0

# For the 3-token example the mask is:
# [[False,  True,  True],   # Token 1 attends only to itself
#  [False, False,  True],   # Token 2 attends to Token 1 and itself
#  [False, False, False]]   # Token 3 attends to Tokens 1-3
print(mask)

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [5]:
# Run the layer on the toy tokens with the causal mask; the cell displays the
# attention output, one row per token.
tmsat(token_encodings=encoding_matrix, mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [6]:
# Recompute the layer's masked self-attention step by step with its own weights,
# so every intermediate tensor can be inspected.

# nn.Linear stores weight as (out_features, in_features) and computes x @ W.T,
# so the transposes show the matrices that right-multiply the token encodings.
print(f"""Weight matrix values for queries, keys, and values are as follows:
      query:\n {tmsat.weight_query.weight.transpose(0,1)}
      key:\n {tmsat.weight_key.weight.transpose(0,1)}
      value:\n {tmsat.weight_value.weight.transpose(0,1)}\n""")

# Queries, Keys, and Values from the layer's projections
q = tmsat.weight_query(encoding_matrix)
k = tmsat.weight_key(encoding_matrix)
v = tmsat.weight_value(encoding_matrix)

print(f"""Calculated Queries, Keys, and Values are as follows:
      Queries:\n {q}
      Keys:\n {k}
      Values:\n {v}\n""")

# Similarity: every query dotted with every key, transposing K along the same
# dimensions the layer is configured with
similarity_scores = torch.matmul(q, k.transpose(tmsat.row_dimension, tmsat.column_dimension))

# Scale by sqrt(key dimension)
scaling_factor = torch.tensor(k.size(tmsat.column_dimension)**0.5, dtype=torch.float32)
scaled_similarity_scores = similarity_scores / scaling_factor

# Mask future tokens: True entries become a large negative number, ~0 after softmax
masked_scaled_similarity_scores = scaled_similarity_scores.masked_fill(mask, value=-1e9)

# Attention percents: softmax across keys, so each query row sums to 1
attention_percents = F.softmax(masked_scaled_similarity_scores, dim=tmsat.column_dimension)

# Attention scores: weighted sum of value vectors
attention_scores = torch.matmul(attention_percents, v)

print(f"""Similarity, Scaled Similarity, Attention Percents, and Attention Scores are as follows:

Similarity Scores:
{similarity_scores}

Scaled Similarity Scores (before masking):
{scaled_similarity_scores}

Mask Used:
{mask}

Scaled Similarity Scores (after masking applied):
{masked_scaled_similarity_scores}

Attention Percentages (after Softmax):
{attention_percents}

Attention Scores:
{attention_scores}""")

# Sanity check: the hand-computed result must match the module's forward pass
torch.testing.assert_close(attention_scores, tmsat(encoding_matrix, mask))


Weight matrix values for queries, keys, and values are as follows:
      weight:
 tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
      key:
 tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
      value:
 tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

Calculate Queries, Keys, and Values are as follows:
      Queries:
 tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)
      Keys:
 tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)
      Values:
 tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)

Similarity, Scaled Similarity, Attention Percents, and Attention Scores are as follows:

Similarity Scores:
tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683, 