<h1>Self-Attention</h1>

Self-attention is a mechanism that calculates attention weights by analyzing the relationships between different parts of a single input sequence, such as words in a sentence or pixels in an image. It learns how these elements relate to each other within the same input.

Traditional sequential models, such as RNNs, process data one step at a time and rely on the context accumulated so far. They struggle to maintain context over long sequences because of exploding or vanishing gradients, and their step-by-step computation cannot be parallelized across the sequence. Another important point is that the relevance of an earlier word generally does not depend on how close it is to the current word.

Let's see this, and the need for self-attention, through an example:

Consider the sentence: "The chef prepared a delicious meal, and it was served with wine." This sentence has 12 words, or tokens. If we focus on the word "it," we need to understand what "it" refers to in the context of the sentence. The words "was" and "served" sit right next to "it," but they don't help us understand what "it" means. Instead, the word "meal," which appears earlier in the sentence, is much more relevant because "it" refers to the "meal."

In this example, proximity isn't the key factor for understanding; the context provided by the word "meal" is what clarifies the meaning of "it."

When a computer processes this sentence, each word is represented as a token with a word embedding, which is a numerical vector representing the word's meaning. In a word embedding space, words that have similar meanings or are used in similar contexts have embeddings that are close to each other. For instance, the word "chef" might be close to "cook" and "kitchen," while "wine" might be close to "beverage" and "drink." To learn more about word embedding, visit [this link](#).
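
One common way to measure how "close" two embeddings are is cosine similarity. The sketch below is a minimal illustration in plain Python; the three vectors are hypothetical values chosen by hand, not embeddings from a trained model.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical 3-dimensional embeddings, chosen by hand for illustration only.
toy_embeddings = {
    "chef": [0.9, 0.8, 0.1],
    "cook": [0.8, 0.9, 0.2],
    "wine": [0.1, 0.2, 0.9],
}

sim_chef_cook = cosine_similarity(toy_embeddings["chef"], toy_embeddings["cook"])
sim_chef_wine = cosine_similarity(toy_embeddings["chef"], toy_embeddings["wine"])
print(f"chef vs cook: {sim_chef_cook:.3f}")
print(f"chef vs wine: {sim_chef_wine:.3f}")
```

With these made-up vectors, "chef" scores much closer to "cook" than to "wine," which is the behavior we would hope for from real embeddings.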

Initially, these embeddings don't capture the relationships between words in our sentence, such as the link between "it" and "meal"; they carry no context from the sentence itself. The goal of self-attention is to refine these general-purpose embeddings so that each one also reflects the context of the current sentence as a whole.

Using self-attention, we calculate how much each word in the sentence should "attend to," or focus on, every other word to understand its context better. For the word "it" in our example, a trained attention mechanism would ideally give more weight to "meal" than to "was" or "served," because "meal" provides the context needed to understand what "it" refers to. In short, self-attention enriches the word embeddings with contextual information, enabling the model to interpret the sentence more accurately.

## Level 1 - Basic Self-Attention Calculations

We first walk through the basic self-attention calculations without trainable weights (we will add weights in the next section).

Take the sentence above: "The chef prepared a delicious meal, and it was served with wine." First, we build a vocabulary and turn the words into tokens \(\{t_1, t_2, t_3, \ldots, t_n\}\). These tokens are then converted into word embeddings \(\{e_1, e_2, e_3, \ldots, e_n\}\) (see the word embedding discussion above).

Now these word embeddings are compared pairwise (e.g., \(e_2\) with each of \(e_1, e_2, e_3, \ldots, e_n\)) to obtain attention scores \(\{a_1, a_2, a_3, \ldots, a_n\}\), which measure how relevant every word is to \(e_2\). For the word "it," we would ideally want "meal" to receive the highest attention score. In this section the scores are simply the dot products of the word embeddings; no trainable weights are involved yet, so the arbitrary toy embeddings used in the code are not expected to reproduce that ideal.

Then the attention scores are converted to attention weights \(\{w_1, w_2, w_3, \ldots, w_n\}\) by normalizing them with softmax so that they sum to 1. In the last step, these weights are used to generate a new representation of each word as the weighted sum of all word embeddings, including its own. (The weight of each word is determined by its attention/relevance score.)
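
Written out for a single word \(e_i\), the three calculations described above can be summarized as follows. The symbol \(c_i\) for the resulting context vector is introduced here for illustration, and softmax is assumed as the normalization, matching the step list in this section.

```latex
\[
a_{ij} = e_i \cdot e_j, \qquad
w_{ij} = \frac{\exp(a_{ij})}{\sum_{k=1}^{n} \exp(a_{ik})}, \qquad
c_i = \sum_{j=1}^{n} w_{ij}\, e_j
\]
```

Here \(a_{ij}\) is the score of word \(j\) for word \(i\), \(w_{ij}\) is the corresponding weight (each row of weights sums to 1), and \(c_i\) is the context-aware representation of word \(i\).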



### Steps for Self-Attention

1. **Tokenization & Embedding**:  
   Represent each word as a vector (tokens: `t1, t2, t3,... tn`, embeddings: `e1, e2, e3,... en`).

2. **Pairwise Comparison**:  
   Compute attention scores between word pairs (e.g., focus of "meal" when interpreting "it").

3. **Normalization**:  
   Apply softmax to convert attention scores to attention weights.

4. **Contextual Representation**:  
   Create a new context-aware word vector by using weighted sums of all embeddings.


In [5]:
### Self-Attention (Without Weights) ###
import torch
#Step 1
word_embeddings = torch.tensor(
  [[0.32, 0.68, 0.45], # The      (x^1)
   [0.71, 0.23, 0.89], # chef     (x^2)
   [0.55, 0.92, 0.37], # prepared (x^3)
   [0.18, 0.79, 0.60], # a        (x^4)
   [0.84, 0.41, 0.13], # delicious(x^5)
   [0.29, 0.63, 0.76], # meal     (x^6)
   [0.50, 0.15, 0.95], # and      (x^7)
   [0.67, 0.38, 0.82], # it       (x^8)
   [0.43, 0.91, 0.26], # was      (x^9)
   [0.75, 0.20, 0.58], # served   (x^10)
   [0.36, 0.72, 0.49], # with     (x^11)
   [0.88, 0.54, 0.11]] # wine     (x^12)
)

# Step 2 Calculate attention scores (Pairwise compare)
attn_scores = word_embeddings @ word_embeddings.T

# Step 3 Calculate attention weights (Normalize)
attn_weights = torch.softmax(attn_scores, dim=-1)

# Step 4 Calculate context vectors
all_context_vecs = attn_weights @ word_embeddings

In [6]:
print(all_context_vecs)

tensor([[0.5270, 0.5664, 0.5374],
        [0.5533, 0.5059, 0.5825],
        [0.5316, 0.5783, 0.5197],
        [0.5150, 0.5726, 0.5456],
        [0.5655, 0.5434, 0.5146],
        [0.5233, 0.5521, 0.5616],
        [0.5458, 0.5044, 0.5926],
        [0.5477, 0.5204, 0.5716],
        [0.5280, 0.5851, 0.5142],
        [0.5601, 0.5151, 0.5592],
        [0.5271, 0.5664, 0.5382],
        [0.5638, 0.5516, 0.5077]])
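
To see what happens inside a single row of the attention matrix, the sketch below recomputes steps 2 and 3 for the word "it" only, in plain Python so that each arithmetic step is visible. It reuses the same toy embeddings as the cell above; the helper names (`softmax`, `query`, `weights`) are chosen here for illustration.

```python
import math

words = ["The", "chef", "prepared", "a", "delicious", "meal",
         "and", "it", "was", "served", "with", "wine"]
# Same toy embeddings as in the cell above (arbitrary values, not trained).
embeddings = [
    [0.32, 0.68, 0.45], [0.71, 0.23, 0.89], [0.55, 0.92, 0.37],
    [0.18, 0.79, 0.60], [0.84, 0.41, 0.13], [0.29, 0.63, 0.76],
    [0.50, 0.15, 0.95], [0.67, 0.38, 0.82], [0.43, 0.91, 0.26],
    [0.75, 0.20, 0.58], [0.36, 0.72, 0.49], [0.88, 0.54, 0.11],
]

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

query = embeddings[words.index("it")]
# Step 2: dot product of "it" with every word (including itself).
scores = [sum(q * k for q, k in zip(query, key)) for key in embeddings]
# Step 3: normalize the scores into weights that sum to 1.
weights = softmax(scores)

# Print words from most to least attended.
for word, weight in sorted(zip(words, weights), key=lambda p: -p[1]):
    print(f"{word:>10}: {weight:.4f}")
```

With these arbitrary embeddings, the largest weight for "it" goes to "chef" rather than "meal," and the weights are spread fairly evenly across all words. This is expected: plain dot products only reflect how similar the raw vectors happen to be, which is one motivation for adding trainable weights.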


<h2>Add Weights to Self-Attention</h2>

### Trainable Weights in Self-Attention

In self-attention, three weight matrices are used to focus on different aspects of the input:

1. **Query (Q) Matrix**:  
   Helps determine what the model should search for when analyzing each word.  
   (Question: What do you want?)

2. **Key (K) Matrix**:  
   Identifies which information within each word is relevant to the query.  
   (Question: Where do you get it from?)

3. **Value (V) Matrix**:  
   Contains the actual content or meaning of the words.  
   (Question: What do you get?)
   
Together, they help the model determine how much attention to give each word.
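
As a rough illustration of what the three matrices do, the sketch below projects the toy embedding for "it" into a query, a key, and a value vector. The 3x2 matrices are hypothetical, hand-picked values; in a real model they are learned during training. The `project` helper is defined here only for this example.

```python
def project(vector, weight_matrix):
    """Multiply a row vector (length d_in) by a d_in x d_out matrix."""
    d_out = len(weight_matrix[0])
    return [
        sum(vector[i] * weight_matrix[i][j] for i in range(len(vector)))
        for j in range(d_out)
    ]

# Embedding for "it" from the earlier example (d_in = 3).
x_it = [0.67, 0.38, 0.82]

# Hypothetical 3x2 weight matrices; real values are learned during training.
W_query = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
W_key = [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
W_value = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]

query = project(x_it, W_query)  # what "it" is looking for
key = project(x_it, W_key)      # what "it" offers to other words' queries
value = project(x_it, W_value)  # the content "it" contributes when attended to
print(query, key, value)
```

The same input embedding yields three different vectors, each of dimension `d_out = 2`, so one word can play three distinct roles in the attention computation.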


### Steps for Self-Attention (with Weight Matrices)

1. **Generate Query, Key, Value Vectors**:  
   Multiply input embeddings with Q, K, V weight matrices to get Query, Key, Value vectors.

2. **Compute Attention Scores**:  
   Multiply the Query matrix by the transpose of the Key matrix to get the raw attention score matrix.

3. **Normalize Scores**:  
   Scale the raw attention scores by the square root of the key dimension, then apply softmax to get attention weights.

4. **Generate Context Vector**:  
   Matrix multiply attention weights with the Value matrix to generate a context vector for each word.
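
In matrix form, these steps are commonly written as scaled dot-product attention. The notation below is introduced here for illustration: \(X\) is the matrix of input embeddings, \(W_Q, W_K, W_V\) are the three weight matrices, and \(d_k\) is the dimension of the key vectors. The division by \(\sqrt{d_k}\) corresponds to the scaling factor used in the code for this section.

```latex
\[
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V
\]
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\]
```

The softmax is applied row by row, so each row of the resulting weight matrix sums to 1.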


In [22]:
### Self-Attention (with weights) ###
import torch
import torch.nn as nn

# Set random seed for reproducibility
torch.manual_seed(1240)

# Define input and output dimensions
d_in = word_embeddings.shape[1]  # Input embedding size (3 in this case)
d_out = 2  # Output embedding size

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, use_bias=False):
        super().__init__()
        # Linear transformations for Query, Key, and Value
        self.W_query = nn.Linear(d_in, d_out, bias=use_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=use_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=use_bias)

    def forward(self, x):
        # Step 1 Transform input to Query, Key, and Value
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        # Step 2 Calculate attention scores
        attn_scores = queries @ keys.T
        
        # Step 3 Apply scaling factor and softmax to get attention weights
        scaling_factor = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scaling_factor, dim=-1)

        # Step 4 Compute context vectors
        context_vectors = attn_weights @ values
        return context_vectors

# Initialize the SelfAttention module
torch.manual_seed(1240)  # Reset seed for consistent results
self_attention = SelfAttention(d_in, d_out)

# Apply self-attention to word embeddings
result = self_attention(word_embeddings)
print(result)

tensor([[0.1348, 0.1801],
        [0.1358, 0.1782],
        [0.1361, 0.1776],
        [0.1346, 0.1803],
        [0.1358, 0.1782],
        [0.1349, 0.1798],
        [0.1348, 0.1799],
        [0.1359, 0.1780],
        [0.1355, 0.1788],
        [0.1355, 0.1787],
        [0.1351, 0.1796],
        [0.1362, 0.1774]], grad_fn=<MmBackward0>)
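
The forward pass above divides the scores by the square root of the key dimension before softmax. The motivation given in the original Transformer paper is that dot products tend to grow in magnitude with the key dimension, which pushes softmax toward a near one-hot output with very small gradients. The sketch below illustrates the effect with hypothetical scores and an assumed `d_k = 64`; the numbers are made up for the example.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for one query against three keys, with d_k = 64.
raw_scores = [8.0, 4.0, 0.0]
d_k = 64
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]  # [1.0, 0.5, 0.0]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)

print("without scaling:", [round(w, 3) for w in unscaled_weights])
print("with scaling:   ", [round(w, 3) for w in scaled_weights])
```

Without scaling, almost all of the weight lands on the first key; with scaling, the distribution is noticeably softer, so the other keys still contribute to the context vector.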


<h2>Causal Attention</h2>

### Causal Attention (Masked Attention)

In some tasks, like next-word prediction, the model needs to be **causal**, meaning it cannot use future information to predict the current word. For instance, an autocomplete model can only see past words when predicting the next one.

Causal attention (or masked attention) achieves this by using a **causal mask** that hides future words. The mask sets the attention scores for future words to negative infinity before softmax is applied, so their attention weights become exactly zero.

### Summarizing Steps for Causal Attention

1. **Compute Attention Scores**:  
   Calculate attention scores between word positions, as done in self-attention.

2. **Apply Causal Mask**:  
   Mask out the scores for future words by setting them to negative infinity (-∞) to ensure they are ignored. This is done by masking scores above the diagonal in the attention matrix.

3. **Normalize and Apply Attention**:  
   After applying the causal mask, normalize the scores using softmax. Future words will have zero attention weight because their scores were set to negative infinity. The resulting attention weights are then used by the model.
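
The masking step can be sketched on a tiny 3-token example in plain Python. The scores below are hypothetical. The sketch also shows why the fill value must be negative infinity rather than zero: since exp(0) = 1, a zero-filled score would still receive a nonzero weight after softmax.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw attention scores for a 3-token sequence
# (row i = query token i, column j = key token j).
scores = [
    [0.9, 0.4, 0.7],
    [0.2, 0.8, 0.5],
    [0.6, 0.3, 0.1],
]
n = len(scores)
NEG_INF = float("-inf")

# Mask positions above the diagonal (j > i), i.e. future tokens.
masked = [[scores[i][j] if j <= i else NEG_INF for j in range(n)] for i in range(n)]
causal_weights = [softmax(row) for row in masked]

# For contrast: filling with 0 instead of -inf still leaks weight to the
# future, because exp(0) = 1 is not zero.
zero_filled = [[scores[i][j] if j <= i else 0.0 for j in range(n)] for i in range(n)]
leaky_weights = [softmax(row) for row in zero_filled]

for row in causal_weights:
    print([round(w, 3) for w in row])
```

In `causal_weights`, the first token attends only to itself, the second to the first two tokens, and so on, while every row still sums to 1.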


In [27]:
### Causal Attention with causal and dropout masks on batched input ###

import torch
import torch.nn as nn

# Create a batch by stacking the word embeddings twice
batch = torch.stack((word_embeddings, word_embeddings), dim=0)
print(batch.shape)  # Shape: (2, 12, 3) - 2 inputs, 12 tokens each, 3-dimensional embeddings

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Linear transformations for Query, Key, and Value
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)
        # Create an upper triangular mask to ensure causality
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # b: batch size, num_tokens: sequence length, d_in: input dimension
        
        # Transform input to Query, Key, and Value
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Calculate attention scores
        attn_scores = queries @ keys.transpose(1, 2)
        
        # Apply causal mask to prevent attending to future tokens
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], 
            -torch.inf
        )
        
        # Apply scaling factor and softmax to get attention weights
        scaling_factor = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scaling_factor, dim=-1)
        
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)

        # Compute context vectors
        context_vectors = attn_weights @ values
        return context_vectors

# Set random seed for reproducibility
torch.manual_seed(1240)

# Initialize CausalAttention module
context_length = batch.shape[1]
d_in = batch.shape[2]
d_out = 2  # Output embedding size
causal_attention = CausalAttention(d_in, d_out, context_length, dropout=0.0)

# Apply causal attention to the batch
context_vectors = causal_attention(batch)

print(context_vectors)
print("context_vectors.shape:", context_vectors.shape)

torch.Size([2, 12, 3])
tensor([[[0.0872, 0.2233],
         [0.1996, 0.0526],
         [0.1382, 0.1952],
         [0.1382, 0.1830],
         [0.1097, 0.2292],
         [0.1286, 0.1924],
         [0.1608, 0.1259],
         [0.1748, 0.1068],
         [0.1528, 0.1494],
         [0.1562, 0.1402],
         [0.1502, 0.1500],
         [0.1362, 0.1774]],

        [[0.0872, 0.2233],
         [0.1996, 0.0526],
         [0.1382, 0.1952],
         [0.1382, 0.1830],
         [0.1097, 0.2292],
         [0.1286, 0.1924],
         [0.1608, 0.1259],
         [0.1748, 0.1068],
         [0.1528, 0.1494],
         [0.1562, 0.1402],
         [0.1502, 0.1500],
         [0.1362, 0.1774]]], grad_fn=<UnsafeViewBackward0>)
context_vectors.shape: torch.Size([2, 12, 2])


<h2> Multi-Head Attention </h2>

### Multi-Head Attention

The core concept of multi-head attention is to apply the attention mechanism several times in parallel, each with its own set of learned linear transformations. This approach enables the model to simultaneously focus on information from different representation subspaces and at various positions within the input, capturing a richer set of dependencies and patterns.  
[Source: Vaswani et al., "Attention Is All You Need," NeurIPS 2017]

Let's go back to our earlier example, "The chef prepared a delicious meal, and it was served with wine." If we focus on the word "meal," we can see that several other words in the sentence are highly relevant to it. Specifically, the words "chef," "prepared," "delicious," and "wine" all provide important context that enhances our understanding of the word "meal."

- The word **"chef"** tells us who is responsible for making the meal.
- The word **"prepared"** indicates the action taken to create the meal.
- The word **"delicious"** describes the quality of the meal.
- The phrase **"with wine"** suggests what accompanies the meal when served.

With a single attention head, the model might struggle to capture all of these associations at once, and could miss some relevant words or only partially capture the relationships between "meal" and the others. With multi-head attention, different heads can learn to focus on different aspects of the sentence. For instance, one head might focus on the relationship between "meal" and "chef," another on "meal" and "delicious," and yet another on "meal" and "wine." (In practice, what each head attends to is learned during training rather than assigned by hand.) By distributing attention across multiple heads, the model builds a richer, more nuanced representation of the word "meal" in context.

### Summarizing Steps for Multi-Head Attention

1. **Learn Different Relationships**:  
   Run multiple attention heads on the input. Each head uses its own weight matrices and produces its own attention weights.

2. **Combine Head Outputs**:  
   The context vectors produced by the individual heads are concatenated and then combined by a linear projection layer.

3. **Generate Context Vector**:  
   The projected output from step 2 is the final enriched representation of the input data.
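
A common way to implement the heads efficiently, and the one used by the `MultiHeadAttention` class in this section, is to compute one large projection and then split it into per-head slices. The sketch below shows only this bookkeeping (splitting and re-concatenating) in plain Python, not the attention computation itself. The function names and the input values are hypothetical.

```python
def split_heads(vectors, num_heads):
    """Split each d_out-dim vector into num_heads chunks of size head_dim.

    Input shape:  num_tokens x d_out
    Output shape: num_heads x num_tokens x head_dim
    """
    d_out = len(vectors[0])
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    head_dim = d_out // num_heads
    return [
        [vec[h * head_dim:(h + 1) * head_dim] for vec in vectors]
        for h in range(num_heads)
    ]

def concat_heads(heads):
    """Inverse of split_heads: concatenate per-head outputs token by token."""
    num_tokens = len(heads[0])
    return [
        [value for head in heads for value in head[t]]
        for t in range(num_tokens)
    ]

# Hypothetical projected vectors for 2 tokens with d_out = 4.
projected = [
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
]
heads = split_heads(projected, num_heads=2)
print(heads)                # one list of token slices per head
print(concat_heads(heads))  # same layout as `projected`
```

Each head sees only its own `head_dim`-sized slice of every token, runs attention independently on it, and the per-head results are concatenated back to `d_out` dimensions before the output projection.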


In [39]:
### Multi-Head Attention ###
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Dimension of each attention head

        # Linear projections for Query, Key, and Value
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # Output projection
        self.out_proj = nn.Linear(d_out, d_out)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Causal mask to prevent attending to future tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # b: batch size, num_tokens: sequence length

        # Linear projections
        keys = self.W_key(x)      # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
        values = self.W_value(x)  # Shape: (b, num_tokens, d_out)

        # Reshape for multi-head attention
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose for attention computation
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute raw dot-product attention scores (scaling is applied just before softmax)
        attn_scores = queries @ keys.transpose(2, 3)  # Shape: (b, num_heads, num_tokens, num_tokens)

        # Apply causal mask
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        # Compute attention weights
        attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to values
        context_vec = attn_weights @ values  # Shape: (b, num_heads, num_tokens, head_dim)
        
        # Reshape and combine heads
        context_vec = context_vec.transpose(1, 2).contiguous()  # (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.view(b, num_tokens, self.d_out)  # (b, num_tokens, d_out)
        
        # Final output projection
        context_vec = self.out_proj(context_vec)

        return context_vec

# Set random seed for reproducibility
torch.manual_seed(1240)

# Get dimensions from the batch
batch_size, context_length, d_in = batch.shape
d_out = 2  # Output dimension

# Initialize MultiHeadAttention
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)

# Apply multi-head attention to the batch
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.6238, -0.3816],
         [ 0.4784, -0.3510],
         [ 0.5800, -0.3691],
         [ 0.5747, -0.3686],
         [ 0.6151, -0.3764],
         [ 0.5870, -0.3709],
         [ 0.5370, -0.3618],
         [ 0.5201, -0.3580],
         [ 0.5527, -0.3642],
         [ 0.5459, -0.3633],
         [ 0.5536, -0.3648],
         [ 0.5761, -0.3689]],

        [[ 0.6238, -0.3816],
         [ 0.4784, -0.3510],
         [ 0.5800, -0.3691],
         [ 0.5747, -0.3686],
         [ 0.6151, -0.3764],
         [ 0.5870, -0.3709],
         [ 0.5370, -0.3618],
         [ 0.5201, -0.3580],
         [ 0.5527, -0.3642],
         [ 0.5459, -0.3633],
         [ 0.5536, -0.3648],
         [ 0.5761, -0.3689]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 12, 2])


## Cross-Attention Mechanism

Cross-attention allows a model to focus on relevant parts of one sequence while processing another. In self-attention, a model weighs the importance of different parts of the same input sequence (i.e., token to token within a sentence); in cross-attention, the queries come from one sequence while the keys and values come from a different one. The two sequences can belong to different modalities or represent different sets of information.

### Example

For our example, we will break our original sentence into two sequences and compute cross-attention between those two sequences. 

We split the sentence into two parts:

- \( x_1 \): "The chef prepared a delicious meal" (first 6 words)
- \( x_2 \): "and it was served with wine" (last 6 words)

The cross-attention mechanism will allow the first part of the sentence to attend to the second part. This means each word in the first part will compute attention weights for each word in the second part.

### Illustration

In a cross-attention scenario, each word in \( x_1 \) will focus on and compute relevance scores with every word in \( x_2 \). This allows the model to incorporate information from \( x_2 \) into the processing of \( x_1 \).
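
The shapes involved can be sketched with a small plain-Python example. The vectors below are hypothetical and are assumed to be already projected (the weight matrices are omitted for brevity): 2 query vectors stand in for tokens of \( x_1 \), and 3 key/value vectors stand in for tokens of \( x_2 \).

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical, already-projected vectors (weight matrices omitted).
queries = [[1.0, 0.0], [0.0, 1.0]]                # from x_1: 2 tokens
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]       # from x_2: 3 tokens
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]   # from x_2: 3 tokens

d_k = len(keys[0])
# One row of weights per query token, one column per key token.
weights = [
    softmax([dot(q, k) / math.sqrt(d_k) for k in keys])
    for q in queries
]
# Each context vector is a weighted sum of the x_2 values.
context = [
    [sum(w * v[d] for w, v in zip(row, values)) for d in range(len(values[0]))]
    for row in weights
]
print(len(weights), "x", len(weights[0]))  # 2 x 3
print(len(context), "x", len(context[0]))  # 2 x 2
```

The weight matrix has one row per token of \( x_1 \) and one column per token of \( x_2 \), so the output has as many rows as \( x_1 \) even though its content is assembled entirely from the values of \( x_2 \).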



In [45]:
### Cross Attention ###
import torch
import torch.nn as nn

word_embeddings = torch.tensor(
  [[0.32, 0.68, 0.45], # The      (x^1)
   [0.71, 0.23, 0.89], # chef     (x^2)
   [0.55, 0.92, 0.37], # prepared (x^3)
   [0.18, 0.79, 0.60], # a        (x^4)
   [0.84, 0.41, 0.13], # delicious(x^5)
   [0.29, 0.63, 0.76], # meal     (x^6)
   [0.50, 0.15, 0.95], # and      (x^7)
   [0.67, 0.38, 0.82], # it       (x^8)
   [0.43, 0.91, 0.26], # was      (x^9)
   [0.75, 0.20, 0.58], # served   (x^10)
   [0.36, 0.72, 0.49], # with     (x^11)
   [0.88, 0.54, 0.11]] # wine     (x^12)
)

class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        
        # Learnable weight matrices for Query, Key, and Value
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x_1, x_2):
        # Compute queries from x_1
        queries = x_1 @ self.W_query
        
        # Compute keys and values from x_2
        keys = x_2 @ self.W_key
        values = x_2 @ self.W_value
        
        # Compute attention scores
        attn_scores = queries @ keys.T
        
        # Apply scaling factor and softmax to get attention weights
        scaling_factor = self.d_out_kq ** 0.5
        attn_weights = torch.softmax(attn_scores / scaling_factor, dim=-1)
        
        # Compute context vectors
        context_vectors = attn_weights @ values
        
        return context_vectors

# Set random seed for reproducibility
torch.manual_seed(42)

# Define dimensions
d_in = word_embeddings.shape[1]  # Input dimension (3 in this case)
d_out_kq = 4  # Output dimension for queries and keys
d_out_v = 2   # Output dimension for values

# Initialize CrossAttention module
cross_attention = CrossAttention(d_in, d_out_kq, d_out_v)

# Split the sentence into two parts for demonstration
x_1 = word_embeddings[:6]  # "The chef prepared a delicious meal"
x_2 = word_embeddings[6:]  # "and it was served with wine"

# Apply cross-attention
result = cross_attention(x_1, x_2)

print("Cross-attention result shape:", result.shape)
print("Cross-attention result:\n", result)

# Interpret the results
for i, word in enumerate(["The", "chef", "prepared", "a", "delicious", "meal"]):
    print(f"{word}: {result[i].tolist()}")

Cross-attention result shape: torch.Size([6, 2])
Cross-attention result:
 tensor([[0.5326, 0.2634],
        [0.5321, 0.2654],
        [0.5345, 0.2637],
        [0.5325, 0.2636],
        [0.5334, 0.2634],
        [0.5322, 0.2642]], grad_fn=<MmBackward0>)
The: [0.5325765609741211, 0.26338499784469604]
chef: [0.5320825576782227, 0.2653651535511017]
prepared: [0.534509539604187, 0.2637292146682739]
a: [0.5324926376342773, 0.2635837197303772]
delicious: [0.533425509929657, 0.2634294927120209]
meal: [0.532193660736084, 0.26424530148506165]


In [49]:
### Cross Attention (random inputs, with shape comments) ###
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        
        # Learnable weight matrices for Query, Key, and Value
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x_1, x_2):
        # x_1: first input sequence (for queries)
        # x_2: second input sequence (for keys and values)
        
        # Compute queries from x_1
        queries = x_1 @ self.W_query  # Shape: (seq_len_1, d_out_kq)
        
        # Compute keys and values from x_2
        keys = x_2 @ self.W_key       # Shape: (seq_len_2, d_out_kq)
        values = x_2 @ self.W_value   # Shape: (seq_len_2, d_out_v)
        
        # Compute attention scores
        attn_scores = queries @ keys.T  # Shape: (seq_len_1, seq_len_2)
        
        # Apply scaling factor and softmax to get attention weights
        scaling_factor = self.d_out_kq ** 0.5
        attn_weights = torch.softmax(attn_scores / scaling_factor, dim=-1)
        
        # Compute context vectors
        context_vectors = attn_weights @ values  # Shape: (seq_len_1, d_out_v)
        
        return context_vectors

# Example usage:
torch.manual_seed(42)  # For reproducibility

# Define dimensions
d_in = 3       # Input dimension
d_out_kq = 4   # Output dimension for queries and keys
d_out_v = 2    # Output dimension for values

# Create random input sequences
seq_len_1 = 5
seq_len_2 = 6
x_1 = torch.rand(seq_len_1, d_in)
x_2 = torch.rand(seq_len_2, d_in)

# Initialize CrossAttention module
cross_attention = CrossAttention(d_in, d_out_kq, d_out_v)

# Apply cross-attention
result = cross_attention(x_1, x_2)

print("Cross-attention result shape:", result.shape)
print("Cross-attention result:\n", result)

Cross-attention result shape: torch.Size([5, 2])
Cross-attention result:
 tensor([[0.4803, 1.2236],
        [0.4700, 1.2026],
        [0.4785, 1.2194],
        [0.4565, 1.1745],
        [0.4840, 1.2308]], grad_fn=<MmBackward0>)
