In [1]:
# Report the version of the interpreter that is actually running this kernel.
# A shelled-out `python` is resolved through PATH and may be a different
# installation than the one the notebook executes in.
import platform

print(f"Python {platform.python_version()}")


Python 3.12.7


In [2]:
pip install torch

Collecting torch
  Using cached torch-2.5.1-cp312-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Using cached typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Collecting networkx (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)
Collecting setuptools (from torch)
  Using cached setuptools-75.5.0-py3-none-any.whl.metadata (6.8 kB)
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Using cached MarkupS

In [3]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention layer.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_k = d_model // num_heads``, attended with
    scaled dot-product attention, then merged and projected back to
    ``d_model``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Args:
            d_model (int): Model dimension/embedding size
            num_heads (int): Number of attention heads
            dropout (float): Dropout probability applied to the attention weights

        Raises:
            AssertionError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        # Per-head feature size.
        self.d_k = d_model // num_heads

        # Query / key / value projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (num_heads, d_k) and transpose to (batch_size, num_heads, seq_length, d_k)
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

        Args:
            Q, K, V (Tensor): Query, key and value tensors whose last two
                dimensions are (seq_length, d_k).
            mask (Tensor, optional): Tensor broadcastable to the shape of the
                attention scores. Positions where ``mask == 0`` are excluded
                from attention; all other positions are kept.

        Returns:
            tuple: ``(attention_output, attention_weights)``. The weights are
            returned after dropout has been applied.
        """
        attention_scores = torch.matmul(Q, K.transpose(-2, -1))
        attention_scores = attention_scores / math.sqrt(self.d_k)

        # Apply mask if provided
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than
            # a hard-coded constant: it cannot overflow in reduced precision,
            # and a fully masked row yields uniform weights instead of NaN.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention_weights = torch.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length, d_model).
            mask (Tensor, optional): Positions where ``mask == 0`` are not
                attended to. A 3-D mask is interpreted as
                (batch_size, seq_length, seq_length) and shared across heads;
                any other mask must be broadcastable to
                (batch_size, num_heads, seq_length, seq_length).

        Returns:
            tuple: ``(output, attention_weights)`` with shapes
            (batch_size, seq_length, d_model) and
            (batch_size, num_heads, seq_length, seq_length).
        """
        batch_size = x.size(0)

        # Linear projections
        Q = self.W_q(x)  # (batch_size, seq_length, d_model)
        K = self.W_k(x)  # (batch_size, seq_length, d_model)
        V = self.W_v(x)  # (batch_size, seq_length, d_model)

        # Split heads
        Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_length, d_k)
        K = self.split_heads(K, batch_size)  # (batch_size, num_heads, seq_length, d_k)
        V = self.split_heads(V, batch_size)  # (batch_size, num_heads, seq_length, d_k)

        # Insert a head axis so a per-batch mask broadcasts over every head
        # instead of having its batch axis aligned with the head axis.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        # Calculate attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads
        attention_output = attention_output.transpose(1, 2)  # (batch_size, seq_length, num_heads, d_k)
        # transpose() returns a non-contiguous view, so make it contiguous before view().
        attention_output = attention_output.contiguous().view(batch_size, -1, self.d_model)

        # Final linear projection
        output = self.W_o(attention_output)

        return output, attention_weights


def example_usage():
    """Run one forward pass on random input and print the tensor shapes."""
    # Demo configuration: (batch size, sequence length, embedding size).
    input_shape = (32, 10, 512)
    num_heads = 8

    # Random input batch, then the layer sized to the input's embedding dim.
    x = torch.randn(*input_shape)
    layer = MultiHeadAttention(input_shape[-1], num_heads)

    # No mask is passed, so every position attends to every other position.
    output, attention_weights = layer(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention_weights.shape}")

# Run the demo when this code is executed directly (as it is when the notebook
# cell runs); importing it as a module skips the demo.
if __name__ == "__main__":
    example_usage()

  cpu = _conversion_method_template(device=torch.device("cpu"))


Input shape: torch.Size([32, 10, 512])
Output shape: torch.Size([32, 10, 512])
Attention weights shape: torch.Size([32, 8, 10, 10])
