# Implement Attention from Scratch
### Problem Statement
Implement a **Scaled Dot-Product Attention** mechanism from scratch using PyTorch. Your mission (should you choose to accept it) is to replicate what PyTorch's built-in `scaled_dot_product_attention` does, but manually. This core component is essential in Transformer architectures and helps models focus on relevant parts of a sequence. You'll test your implementation against PyTorch's native one to ensure you nailed it.


### Requirements
1. **Define the Function**:
   - Create a function `scaled_dot_product_attention(q, k, v, mask=None)` that:
     - Computes attention scores via the dot product of query and key vectors.
     - Scales the scores by dividing them by the square root of the key dimension `d_k`.
     - Applies an optional mask to the scores.
     - Applies softmax to convert scores into attention weights.
     - Uses these weights to compute a weighted sum of values (V).

2. **Test Your Work**:
   - Use sample tensors for query (Q), key (K), and value (V).
   - Compare the result of your custom implementation with PyTorch's `F.scaled_dot_product_attention` using an `assert` to check numerical accuracy.
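Taken together, the steps in requirement 1 amount to the standard scaled dot-product attention formula. Here $d_k$ is the key dimension, and $M$ is an optional additive mask that is $0$ at allowed positions and $-\infty$ at masked ones (the additive-mask notation is introduced here only to express what `masked_fill` does):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V
$$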


### Constraints
- ❌ Do NOT use `F.scaled_dot_product_attention` inside your custom function: that defeats the whole point.
- ✅ Your implementation must handle **batch dimensions** correctly.
- ✅ Support optional **masking** for future tokens or padding.
- ✅ Use only PyTorch ops: no cheating with external attention libs.
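The batching and masking constraints can be made concrete with a small sketch. It assumes the same convention as PyTorch's boolean `attn_mask` (`True` means the position may be attended to) and shows how a causal mask and a padding mask can be built with plain tensor ops and combined by broadcasting. The sizes and the `lengths` tensor are illustrative assumptions, not part of the exercise.

```python
import torch

batch, seq_len = 2, 4
# Hypothetical per-example valid lengths: the second sequence ends with one padding token.
lengths = torch.tensor([4, 3])

# Causal mask: query i may attend to keys 0..i. Shape (seq_len, seq_len).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Padding mask: key j is valid if j < length. Shape (batch, 1, seq_len);
# the singleton dimension broadcasts over the query axis.
padding = (torch.arange(seq_len)[None, :] < lengths[:, None]).unsqueeze(1)

# Combining them broadcasts to (batch, seq_len, seq_len).
mask = causal & padding
print(mask.shape)  # torch.Size([2, 4, 4])
print(mask[1].int())
```

Because the padding mask only hides keys past each sequence's length, every query row still keeps at least key 0, which matters for the softmax step.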



<details>
  <summary>üí° Hint</summary>
  Use `torch.matmul()` to compute dot products and `F.softmax()` for the final attention weights. The mask (if used) should be applied **before** the softmax using `masked_fill`, setting masked positions to `-inf` so they receive zero weight.
</details>
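To see why the order in the hint matters, here is a minimal sketch on one row of made-up scores: filling masked positions with `-inf` before `F.softmax` gives them exactly zero weight while the remaining weights still sum to 1. Zeroing the weights after the softmax instead leaves a row that no longer sums to 1.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])  # illustrative scores for one query
mask = torch.tensor([[1, 1, 0, 0]])            # 0 marks positions to hide

# Correct: mask first, then softmax.
weights = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
print(weights)              # the last two entries are exactly 0
print(weights.sum(dim=-1))  # sums to 1 (up to rounding)

# Incorrect: softmax first, then zero out; the kept weights no longer sum to 1.
wrong = F.softmax(scores, dim=-1).masked_fill(mask == 0, 0.0)
print(wrong.sum(dim=-1))
```

One caveat: if every position in a row is masked, the row is all `-inf` and softmax returns NaN, so a mask should leave at least one valid key per query.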


```python
import torch
import torch.nn.functional as F
```

```python
# Synthetic data
torch.manual_seed(42)
q = torch.tensor([[[1.0, 0.5, 0.2],  # Query for token 1
                   [0.3, 1.2, 0.7],  # Query for token 2
                   [0.8, 0.1, 1.5]]])  # Query for token 3

k = torch.tensor([[[1.0, 0.2, 0.3],  # Key for token 1
                   [0.4, 1.5, 0.6],  # Key for token 2
                   [0.7, 0.1, 1.8]]])  # Key for token 3

v = torch.tensor([[[10.0, 2.0, 3.0],  # Value for token 1
                   [4.0, 15.0, 6.0],  # Value for token 2
                   [7.0, 1.0, 18.0]]])  # Value for token 3

print(q.shape)
```

```
torch.Size([1, 3, 3])
```

```python
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute the scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_len_q, d_k)
        k: Key tensor of shape (..., seq_len_k, d_k)
        v: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional mask broadcastable to (..., seq_len_q, seq_len_k);
              positions where the mask is 0 (or False) are excluded.

    Returns:
        output: Attention output tensor of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights tensor of shape (..., seq_len_q, seq_len_k)
    """
    d_k = q.shape[-1]  # Feature dimension shared by queries and keys

    # Dot product of Q and K^T, scaled by sqrt(d_k); transpose(-2, -1) leaves leading batch dims intact
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

    # Apply mask if provided: masked positions get -inf so softmax assigns them zero weight
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over the key axis (last dim) turns each row of scores into weights that sum to 1
    attention_weights = F.softmax(scores, dim=-1)

    # Compute output by weighting V with the attention weights
    output = torch.matmul(attention_weights, v)

    return output, attention_weights
```
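The comparison in this notebook only covers a single, unmasked 3-D batch. A sketch of a stricter check, assuming PyTorch 2.x (where `F.scaled_dot_product_attention` accepts `attn_mask` and `is_causal`), runs the same math on 4-D tensors shaped `(batch, heads, seq_len, d_k)` with a causal mask. The block re-declares a compact copy of the function so it runs on its own; the seed and tensor sizes are arbitrary choices.

```python
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    # Compact copy of the notebook's function so this block is self-contained.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 5, 8
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)
v = torch.randn(batch, heads, seq_len, d_k)

# Boolean causal mask (True = may attend); it broadcasts over batch and heads.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_custom, weights = scaled_dot_product_attention(q, k, v, mask=causal)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_custom, out_mask, atol=1e-5))
print(torch.allclose(out_custom, out_causal, atol=1e-5))
print(torch.allclose(weights.sum(dim=-1), torch.ones(batch, heads, seq_len)))
```

The slightly looser `atol` allows for the built-in kernel accumulating in a different order than the manual version.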

```python
# Test on the sample data and compare with PyTorch's built-in implementation
output_custom, _ = scaled_dot_product_attention(q, k, v)
print(output_custom)
output = F.scaled_dot_product_attention(q, k, v)
print(output)

assert torch.allclose(output_custom, output, atol=1e-08, rtol=1e-05)  # Element-wise: |a - b| <= atol + rtol * |b|
```

```
tensor([[[ 6.9352,  6.2411,  8.8509],
         [ 6.1200,  8.0314,  9.2154],
         [ 6.9659,  4.0257, 12.7036]]])
tensor([[[ 6.9352,  6.2411,  8.8509],
         [ 6.1200,  8.0314,  9.2154],
         [ 6.9659,  4.0257, 12.7036]]])
```
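As a sanity check on the printed result, the first output row can be reproduced by hand. With $d_k = 3$ and the first query $q_1 = [1.0, 0.5, 0.2]$ (intermediate values rounded to four decimals):

$$
\begin{aligned}
q_1 K^\top &= [1.16,\ 1.27,\ 1.11] \\
\frac{q_1 K^\top}{\sqrt{3}} &\approx [0.6697,\ 0.7332,\ 0.6409] \\
\mathrm{softmax}(\cdot) &\approx [0.3293,\ 0.3508,\ 0.3199] \\
\text{output}_1 &\approx 0.3293 \cdot [10, 2, 3] + 0.3508 \cdot [4, 15, 6] + 0.3199 \cdot [7, 1, 18] \approx [6.935,\ 6.241,\ 8.851]
\end{aligned}
$$

This matches the first row `[6.9352, 6.2411, 8.8509]` printed by both implementations.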
