# Lecture 4 Self Attention - Masked Attention
## Example 2

#### Masked Attention

Masked attention is a crucial mechanism in transformer architectures, particularly in tasks like language modeling, where it's essential to prevent the model from accessing future tokens during training. This ensures that the predictions for a particular position depend only on the known outputs at positions before it.
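To make the idea concrete, here is a minimal sketch of the masking step on its own, assuming plain scaled dot-product attention with no learned projections. The names `Q`, `K`, `future`, and `weights` are illustrative choices, not part of the lecture code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, d_k = 4, 8
Q = torch.rand(seq_length, d_k)  # one query vector per position
K = torch.rand(seq_length, d_k)  # one key vector per position

# Raw scaled dot-product scores, shape (seq_length, seq_length)
scores = Q @ K.T / d_k ** 0.5

# True above the diagonal marks the "future" positions to hide
future = torch.ones(seq_length, seq_length).triu(diagonal=1).bool()

# Masked scores become -inf, so softmax gives them zero weight
masked_scores = scores.masked_fill(future, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)
print(weights)
```

Each row of `weights` still sums to 1, but all mass sits on the current and earlier positions.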

Below are two Python examples demonstrating masked attention:
- From Scratch Implementation Using NumPy
- Using PyTorch's Built-in Modules

#### 2. Masked Attention Using PyTorch
Leveraging PyTorch's built-in nn.MultiheadAttention module simplifies the implementation of masked attention. PyTorch allows the creation of a causal mask to ensure that each position can only attend to previous positions and itself.

**Step-by-Step Explanation**
- Import Libraries and Set Parameters:
    - Define dimensions, number of heads, and other necessary parameters.
- Initialize the MultiheadAttention Module:
    - Create an instance of nn.MultiheadAttention with the specified embedding dimension and number of heads.
- Prepare Input Data:
    - PyTorch expects input shapes as (seq_length, batch_size, embedding_dim) unless batch_first=True is set.
    - Initialize random input tensors for queries, keys, and values.
- Create Causal Mask:
    - Generate a mask that prevents each position from attending to future positions.
- Perform Multi-Head Attention:
    - Pass the queries, keys, and values to the forward method of the MultiheadAttention module along with the mask.
    - Obtain the output and attention weights.
- Output the Results:
    - Display the attention output and the attention weights.
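For the "Create Causal Mask" step, it may help to know that `attn_mask` accepts two conventions: a boolean mask where `True` means "not allowed to attend", or a float mask that is added to the attention scores. The sketch below builds both and compares the results, under the assumption of small illustrative dimensions; the variable names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length, embedding_dim, num_heads = 4, 8, 2
mha = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
x = torch.rand(1, seq_length, embedding_dim)

# Boolean convention: True blocks attention to that position
bool_mask = torch.ones(seq_length, seq_length).triu(diagonal=1).bool()

# Float convention: the mask is added to the scores, so -inf blocks a position
float_mask = torch.zeros(seq_length, seq_length).masked_fill(bool_mask, float("-inf"))

out_bool, w_bool = mha(x, x, x, attn_mask=bool_mask)
out_float, w_float = mha(x, x, x, attn_mask=float_mask)

# Both conventions should describe the same causal pattern
print(torch.allclose(out_bool, out_float, atol=1e-6))
print(torch.allclose(w_bool, w_float, atol=1e-6))
```

The boolean form is usually easier to read, while the float form is convenient when a mask needs to carry additive biases as well.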

**Explanation of Output**
- Input X: A randomly generated input tensor with shape (batch_size, seq_length, embedding_dim).
- Attention Weights: After applying the mask, each position can only attend to itself and previous positions. For example, the first position attends only to itself, the second attends to the first and second, and so on.
- Output: The result of the masked attention mechanism, maintaining the same shape as the input.

Note: Because the seed is fixed with torch.manual_seed(42), the values are reproducible on a given setup, but they may differ across PyTorch versions or hardware.
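One way to check that each position depends only on itself and earlier positions is to change the last token and confirm that the earlier outputs do not move. This is a minimal sketch that assumes the same parameters as this example; `X_changed`, `out_a`, and `out_b` are illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
seq_length, embedding_dim, num_heads = 4, 8, 2
mha = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
mask = torch.ones(seq_length, seq_length).triu(diagonal=1).bool()

X = torch.rand(1, seq_length, embedding_dim)
X_changed = X.clone()
X_changed[:, -1] = torch.rand(embedding_dim)  # overwrite only the last token

out_a, _ = mha(X, X, X, attn_mask=mask)
out_b, _ = mha(X_changed, X_changed, X_changed, attn_mask=mask)

# Positions before the last one never see the changed token
earlier_same = torch.allclose(out_a[:, :-1], out_b[:, :-1], atol=1e-6)
# The last position attends to itself, so its output should change
last_differs = not torch.allclose(out_a[:, -1], out_b[:, -1], atol=1e-6)
print(earlier_same, last_differs)
```

Without the mask, the same experiment would generally change every row of the output, since every position could attend to the modified token.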

```python
import torch
import torch.nn as nn

# Parameters
torch.manual_seed(42)  # For reproducibility
batch_size = 1
seq_length = 4
embedding_dim = 8
num_heads = 2

# Initialize MultiheadAttention
multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

# Random input
X = torch.rand(batch_size, seq_length, embedding_dim)

# In PyTorch's MultiheadAttention, when batch_first=True, input shape is (batch_size, seq_length, embedding_dim)
# Otherwise, it should be (seq_length, batch_size, embedding_dim)

# Create a causal mask to prevent attention to future tokens
# The mask shape should be (seq_length, seq_length); True marks positions that must NOT be attended to
mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()  # Strictly upper triangular matrix

# Perform Multi-Head Attention
# Using the same tensor for queries, keys, and values (self-attention)
attn_output, attn_weights = multihead_attn(X, X, X, attn_mask=mask)

# Display Results
torch.set_printoptions(precision=4, sci_mode=False)
print("Input X:\n", X)
print("\nAttention Weights:\n", attn_weights)
print("\nAttention Output:\n", attn_output)
```

```text
Input X:
 tensor([[[0.6855, 0.9696, 0.4295, 0.4961, 0.3849, 0.0825, 0.7400, 0.0036],
         [0.8104, 0.8741, 0.9729, 0.3821, 0.0892, 0.6124, 0.7762, 0.0023],
         [0.3865, 0.2003, 0.4563, 0.2539, 0.2956, 0.3413, 0.0248, 0.9103],
         [0.9192, 0.4216, 0.4431, 0.2959, 0.0485, 0.0134, 0.6858, 0.2255]]])

Attention Weights:
 tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5110, 0.4890, 0.0000, 0.0000],
         [0.3340, 0.3238, 0.3422, 0.0000],
         [0.2523, 0.2368, 0.2546, 0.2563]]], grad_fn=<MeanBackward1>)

Attention Output:
 tensor([[[    -0.0897,     -0.2322,      0.0044,      0.1931,     -0.2976,
              -0.0187,     -0.1550,     -0.0145],
         [    -0.1037,     -0.2087,     -0.0187,      0.1829,     -0.3419,
              -0.0234,     -0.2459,      0.0579],
         [    -0.0225,     -0.2682,      0.0002,      0.0759,     -0.2667,
              -0.0839,     -0.2246,     -0.0096],
         [    -0.0267,     -0.2767,      0.0029,      0.0618,     -0.2711
```

**Explanation of Output**
- Input X: A randomly generated input tensor with shape (batch_size, seq_length, embedding_dim).
- Attention Output: Each position in the sequence has been transformed based on the masked attention mechanism. The output maintains the same shape as the input.
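- Attention Weights: The attention weights averaged across heads (the default behavior of nn.MultiheadAttention, visible in the grad_fn=<MeanBackward1> tag), with shape (batch_size, seq_length, seq_length). Due to the mask, positions can only attend to themselves and previous positions. For instance, the first token only attends to itself, the second token attends to the first and second, and so on. In this random example, the weights are close to, but not exactly, uniform over the allowed positions (about 0.51 and 0.49 in the second row), which is what one would expect from untrained, randomly initialized projections.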
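To inspect the weights of each head separately rather than their average, the forward call accepts `average_attn_weights=False`. The sketch below assumes a PyTorch version that supports this argument (1.11 or later, to the best of my knowledge) and reuses the parameters of this example; `avg_weights` and `head_weights` are illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
batch_size, seq_length, embedding_dim, num_heads = 1, 4, 8, 2
mha = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
X = torch.rand(batch_size, seq_length, embedding_dim)
mask = torch.ones(seq_length, seq_length).triu(diagonal=1).bool()

# Default: weights are averaged over heads
_, avg_weights = mha(X, X, X, attn_mask=mask)
# Per-head weights: one (seq_length, seq_length) matrix per head
_, head_weights = mha(X, X, X, attn_mask=mask, average_attn_weights=False)

print(avg_weights.shape)   # (batch_size, seq_length, seq_length)
print(head_weights.shape)  # (batch_size, num_heads, seq_length, seq_length)

# Averaging the per-head weights over the head dimension recovers the default output
print(torch.allclose(head_weights.mean(dim=1), avg_weights, atol=1e-6))
```

Every head respects the same causal mask, so the entries above the diagonal are zero in each per-head matrix as well as in the average.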