## Multi-head self-attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention without masking.

    Every position of the input sequence attends to every position of the same
    sequence. The embedding is split evenly across ``heads`` and the per-head
    results are concatenated and projected back to ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        """
        Build the projection layers.

        Parameters:
        - embed_size: Dimensionality of the input (and output) embeddings.
        - heads: Number of attention heads; must divide embed_size exactly.
        """
        super().__init__()

        self.heads = heads
        # Width of each head's slice of the embedding.
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by number of heads"

        # Bias-free Q/K/V projections, then a biased output projection.
        # Creation order is kept fixed so seeded initialisation stays reproducible.
        self.W_q = nn.Linear(embed_size, embed_size, bias=False)
        self.W_k = nn.Linear(embed_size, embed_size, bias=False)
        self.W_v = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, projected, n_batch, n_tokens):
        """Reshape (batch, tokens, embed) into (batch, heads, tokens, head_dim)."""
        return projected.view(n_batch, n_tokens, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """
        Apply self-attention to ``x``.

        Parameters:
        - x: Tensor of shape (batch_size, sequence_length, embed_size).

        Returns:
        - Tensor of shape (batch_size, sequence_length, embed_size).
        """
        n_batch, n_tokens, width = x.shape

        queries = self._split_heads(self.W_q(x), n_batch, n_tokens)
        keys = self._split_heads(self.W_k(x), n_batch, n_tokens)
        values = self._split_heads(self.W_v(x), n_batch, n_tokens)

        # sqrt(d_k) scaling keeps the dot products in a softmax-friendly range.
        scale = torch.tensor(self.head_dim, dtype=torch.float32).sqrt()
        scores = (queries @ keys.transpose(-2, -1)) / scale  # (batch, heads, tokens, tokens)
        weights = scores.softmax(dim=-1)

        mixed = weights @ values  # (batch, heads, tokens, head_dim)
        # Concatenate the heads back into a single embedding per token.
        merged = mixed.transpose(1, 2).reshape(n_batch, n_tokens, width)
        return self.fc_out(merged)



In [4]:
# Example usage: run a random batch through SelfAttention and check the output shape.
if __name__ == "__main__":
    torch.manual_seed(0)  # Fixed seed so the printed output is reproducible on Restart & Run All

    batch_size = 2
    seq_length = 5
    embed_size = 16
    heads = 4

    x = torch.rand(batch_size, seq_length, embed_size)  # Example input embeddings

    self_attention = SelfAttention(embed_size, heads)
    output = self_attention(x)

    # Self-attention preserves the input shape.
    assert output.shape == (batch_size, seq_length, embed_size)
    print("Self-Attention Output Shape:", output.shape)  # Expected: (batch_size, seq_length, embed_size)
    print("Self-Attention Output:", output)

Self-Attention Output Shape: torch.Size([2, 5, 16])
Self-Attention Output: tensor([[[ 4.0854e-01,  1.0211e-01, -1.9194e-01,  7.7359e-02,  4.2792e-01,
          -4.5138e-01,  4.4197e-02, -3.5378e-01,  1.9782e-01, -1.0165e-01,
          -2.4444e-01, -4.8325e-03,  1.3205e-01, -1.8331e-01, -3.3692e-01,
          -1.3270e-01],
         [ 4.0888e-01,  1.0327e-01, -1.9213e-01,  7.6321e-02,  4.2641e-01,
          -4.5309e-01,  4.3546e-02, -3.5363e-01,  2.0017e-01, -1.0099e-01,
          -2.4342e-01, -5.2806e-03,  1.3193e-01, -1.8099e-01, -3.3793e-01,
          -1.3164e-01],
         [ 4.0837e-01,  1.0174e-01, -1.9085e-01,  7.7388e-02,  4.2804e-01,
          -4.5011e-01,  4.1277e-02, -3.5331e-01,  1.9723e-01, -1.0178e-01,
          -2.4482e-01, -4.1396e-03,  1.3138e-01, -1.7882e-01, -3.3640e-01,
          -1.3221e-01],
         [ 4.0903e-01,  1.0549e-01, -1.9447e-01,  7.6325e-02,  4.2731e-01,
          -4.5252e-01,  4.2496e-02, -3.5537e-01,  1.9656e-01, -1.0277e-01,
          -2.4426e-01, -6.67

## Masked self-attention

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        """
        Initialize Masked Self-Attention module.

        Parameters:
        - embed_size: Dimensionality of input embeddings.
        - heads: Number of attention heads (multi-head attention splits embed_size into smaller heads).
        """
        super(MaskedSelfAttention, self).__init__()

        # Number of heads for multi-head attention
        self.heads = heads
        self.head_dim = embed_size // heads  # Splitting embedding into smaller dimensions for each head

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by number of heads"

        # Learnable weight matrices for Query, Key, and Value transformations
        self.W_q = nn.Linear(embed_size, embed_size, bias=False)  # Query weight matrix
        self.W_k = nn.Linear(embed_size, embed_size, bias=False)  # Key weight matrix
        self.W_v = nn.Linear(embed_size, embed_size, bias=False)  # Value weight matrix
        self.fc_out = nn.Linear(embed_size, embed_size)  # Final output projection

    def forward(self, x, mask=None):
        """
        Forward pass for Masked Self-Attention.

        Parameters:
        - x: Input tensor of shape (batch_size, sequence_length, embed_size)
        - mask: Optional tensor broadcastable to (batch_size, heads, seq_length, seq_length),
          e.g. (1, 1, seq_length, seq_length) for a causal mask or
          (batch_size, 1, seq_length, seq_length) for per-example masks.
          Entries equal to 0 mark key positions a query may NOT attend to.
          None (the default) disables masking.

        Returns:
        - out: Masked self-attention output of shape (batch_size, sequence_length, embed_size)
        """
        batch_size, seq_length, embed_size = x.shape

        # Transform input embeddings into Q, K, V matrices (batch_size, seq_length, embed_size)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape Q, K, V to (batch_size, heads, seq_length, head_dim) for multi-head processing
        Q = Q.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_length, self.heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores QK^T / sqrt(d_k): (batch_size, heads, seq_length, seq_length)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # masked_fill expects the mask on the same device as the scores.
            mask = mask.to(attention_scores.device)
            # Fill excluded positions with the most negative finite value rather than -inf.
            # When a row has at least one visible key, the masked weights still underflow
            # to exactly 0, so results are unchanged. A row whose keys are ALL masked
            # (e.g. left padding under a causal mask) gets uniform weights instead of
            # softmax(all -inf) = NaN, which would poison the output and the gradients.
            attention_scores = attention_scores.masked_fill(
                mask == 0, torch.finfo(attention_scores.dtype).min
            )

        # Normalize scores over the key axis
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Attention-weighted sum of values: (batch_size, heads, seq_length, head_dim)
        out = torch.matmul(attention_weights, V)

        # Concatenate heads back to (batch_size, seq_length, embed_size)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_size)

        # Final linear transformation
        return self.fc_out(out)


In [6]:
# Function to create a mask for causal self-attention
def create_causal_mask(seq_length, device=None):
    """
    Create a lower-triangular (causal) attention mask.

    Entry [i, j] is 1 when query position i may attend to key position j (j <= i)
    and 0 for future positions (j > i), which MaskedSelfAttention excludes via `mask == 0`.

    Parameters:
    - seq_length: Number of positions in the sequence.
    - device: Optional device to allocate the mask on (e.g. `x.device` of the attention
      input), so it need not be moved separately. Defaults to PyTorch's default device.

    Returns:
    - Tensor of ones/zeros in the default floating dtype with shape
      (1, 1, seq_length, seq_length); the leading singleton dims broadcast over batch and heads.
    """
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))  # (seq_length, seq_length)
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_length, seq_length)


In [7]:
# Example usage: causal (masked) self-attention over a random batch.
if __name__ == "__main__":
    torch.manual_seed(0)  # Fixed seed so the printed tensors are reproducible on Restart & Run All

    batch_size = 2
    seq_length = 5
    embed_size = 16
    heads = 4

    x = torch.rand(batch_size, seq_length, embed_size)  # Example input embeddings
    mask = create_causal_mask(seq_length)  # Causal mask of shape (1, 1, seq_length, seq_length)

    masked_self_attention = MaskedSelfAttention(embed_size, heads)
    output = masked_self_attention(x, mask)

    assert output.shape == (batch_size, seq_length, embed_size)
    assert mask.shape == (1, 1, seq_length, seq_length)
    print("Masked Self-Attention Output Shape:", output.shape)  # Expected: (batch_size, seq_length, embed_size)
    print("Masked Self-Attention Output:", output)
    print("Masked Self-Attention Mask:", mask)  # Lower-triangular ones: position i sees positions 0..i

Masked Self-Attention Output Shape: torch.Size([2, 5, 16])
Masked Self-Attention Output: tensor([[[ 0.0433,  0.4703, -0.1431,  0.2472,  0.2905, -0.2199,  0.2942,
           0.1567, -0.0184,  0.2952, -0.0145, -0.1265,  0.3437, -0.0480,
           0.1510, -0.4225],
         [ 0.0464,  0.3199, -0.0126,  0.1766,  0.3083, -0.1797,  0.3126,
           0.2142,  0.0722,  0.3787, -0.1313, -0.0768,  0.3029,  0.0247,
           0.1271, -0.5003],
         [-0.0370,  0.2891, -0.0260,  0.1492,  0.2494, -0.1769,  0.2930,
           0.2498,  0.0917,  0.3124, -0.1894, -0.1639,  0.2779,  0.0321,
           0.1695, -0.5547],
         [-0.0650,  0.2304, -0.0118,  0.1525,  0.2306, -0.1136,  0.3311,
           0.2630,  0.0810,  0.3415, -0.1704, -0.2008,  0.2072,  0.0012,
           0.1693, -0.4700],
         [-0.0916,  0.2251, -0.0165,  0.1363,  0.2353, -0.1044,  0.3241,
           0.2303,  0.0852,  0.2971, -0.1761, -0.2521,  0.1796, -0.0121,
           0.1802, -0.4444]],

        [[-0.0939,  0.2634, -0.099

## Cross-attention

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_size, heads):
        """
        Initialize Cross-Attention module.

        Parameters:
        - embed_size: Dimensionality of input embeddings.
        - heads: Number of attention heads (multi-head attention splits embed_size into smaller heads).
        """
        super(CrossAttention, self).__init__()

        # Number of heads for multi-head attention
        self.heads = heads
        self.head_dim = embed_size // heads  # Splitting embedding into smaller dimensions for each head

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by number of heads"

        # Learnable weight matrices for Query (decoder), Key (encoder), and Value (encoder)
        self.W_q = nn.Linear(embed_size, embed_size, bias=False)  # Query weight matrix (from decoder)
        self.W_k = nn.Linear(embed_size, embed_size, bias=False)  # Key weight matrix (from encoder)
        self.W_v = nn.Linear(embed_size, embed_size, bias=False)  # Value weight matrix (from encoder)
        self.fc_out = nn.Linear(embed_size, embed_size)  # Final output projection

    def forward(self, decoder_x, encoder_x, mask=None):
        """
        Forward pass for Cross-Attention.

        Parameters:
        - decoder_x: Decoder input tensor of shape (batch_size, target_seq_length, embed_size)
        - encoder_x: Encoder output tensor of shape (batch_size, source_seq_length, embed_size)
        - mask: Optional tensor broadcastable to
          (batch_size, heads, target_seq_length, source_seq_length), e.g.
          (batch_size, 1, 1, source_seq_length) to hide padded encoder positions.
          Entries equal to 0 mark encoder positions a decoder query may NOT attend to.
          None (the default) lets every query attend to every encoder position.

        Returns:
        - out: Cross-attention output of shape (batch_size, target_seq_length, embed_size)
        """
        batch_size, target_seq_length, embed_size = decoder_x.shape
        source_seq_length = encoder_x.shape[1]  # Length of the encoder's output sequence

        # Queries come from the decoder; keys and values come from the encoder
        Q = self.W_q(decoder_x)
        K = self.W_k(encoder_x)
        V = self.W_v(encoder_x)

        # Reshape for multi-head attention: (batch_size, heads, seq_length, head_dim)
        Q = Q.view(batch_size, target_seq_length, self.heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, source_seq_length, self.heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, source_seq_length, self.heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores QK^T / sqrt(d_k): (batch_size, heads, target_seq_length, source_seq_length)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # masked_fill expects the mask on the same device as the scores.
            mask = mask.to(attention_scores.device)
            # Use the most negative finite value instead of -inf. Masked weights still
            # underflow to exactly 0 when any encoder position is visible. A query with
            # every position masked gets uniform weights instead of NaN from softmax(all -inf).
            attention_scores = attention_scores.masked_fill(
                mask == 0, torch.finfo(attention_scores.dtype).min
            )

        # Normalize scores over the encoder (key) axis
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Attention-weighted sum of encoder values: (batch_size, heads, target_seq_length, head_dim)
        out = torch.matmul(attention_weights, V)

        # Concatenate heads back to (batch_size, target_seq_length, embed_size)
        out = out.transpose(1, 2).contiguous().view(batch_size, target_seq_length, embed_size)

        # Final linear transformation
        return self.fc_out(out)


In [9]:
# Example usage: decoder queries attending over a (longer) encoder output.
if __name__ == "__main__":
    torch.manual_seed(0)  # Fixed seed so results are reproducible on Restart & Run All

    batch_size = 2
    source_seq_length = 6  # Encoder output length
    target_seq_length = 4  # Decoder input length
    embed_size = 16
    heads = 4

    encoder_output = torch.rand(batch_size, source_seq_length, embed_size)  # Example encoder output
    decoder_input = torch.rand(batch_size, target_seq_length, embed_size)  # Example decoder input

    cross_attention = CrossAttention(embed_size, heads)
    output = cross_attention(decoder_input, encoder_output)

    # The output follows the decoder's length, not the encoder's.
    assert output.shape == (batch_size, target_seq_length, embed_size)
    print("Cross-Attention Output Shape:", output.shape)  # Expected: (batch_size, target_seq_length, embed_size)


Cross-Attention Output Shape: torch.Size([2, 4, 16])
