In [1]:
from __future__ import annotations

import math

import torch
import torch.nn as nn
from torch import Tensor, BoolTensor
from torch.nn import functional as F

# Building the attention part

### Key point: 

Attention allows modern neural networks to focus on the most relevant pieces of the input, whether text, images, or multimodal inputs.

### Input embedding:

After word2vec, tokens are structured as vectors.
embedding_size = position_embedding size; token vector = word embedding + position embedding.
token vector -> linear layer -> hidden_size
A larger hidden size means more information, more parameters, and more cost. Typical values are 512, 768, and 1024.
Multi-head attention optimizes the use of hidden_size.

## Single Head Self-Attention

In [2]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a reduced head width.

    The input is projected to query, key and value tensors of width
    ``hidden_size // 4``, attention is computed across the sequence dimension,
    and the attended values are projected back to ``hidden_size``.

    Args:
        hidden_size (int): Size of the last dimension of the input and output.
        bias (bool): Whether the linear layers use a bias term. Default is True.
    """

    def __init__(self, hidden_size: int, bias: bool = True):
        super().__init__()

        # Width of the single head. Stored once so the projections and the
        # reshape in forward() always agree.
        self.head_size = hidden_size // 4

        # Fused linear layer producing query, key and value in one matmul:
        #   qkv_projected = input_tensor @ qkv_projection.weight^T + qkv_projection.bias
        #   qkv_projection.weight.shape == (3 * head_size, hidden_size)
        self.qkv_projection = nn.Linear(hidden_size, self.head_size * 3, bias=bias)

        # Linear layer for the final output projection back to hidden_size
        self.output_projection = nn.Linear(self.head_size, hidden_size, bias=bias)

    def forward(self, input_tensor: Tensor):
        """Apply self-attention to ``input_tensor``.

        Args:
            input_tensor (Tensor): Tensor of shape
                (batch_size, sequence_length, hidden_size).

        Returns:
            Tensor: Tensor of shape (batch_size, sequence_length, hidden_size).
        """
        batch_size, sequence_length, _ = input_tensor.shape

        # Project once, then split the last dimension into the q / k / v slots
        qkv_projected = self.qkv_projection(input_tensor)
        qkv_projected = qkv_projected.reshape(batch_size, sequence_length, 3, self.head_size)
        q, k, v = qkv_projected.unbind(dim=2)

        # Scaled dot-product scores. The scale is a plain Python float, so no
        # tensor is allocated per call and no integer-to-float promotion is needed.
        attention_weights = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        attention_weights = torch.softmax(attention_weights, dim=-1)

        # Weighted sum of the value vectors for every query position
        attended_values = attention_weights @ v

        # Project attended values to the final output
        return self.output_projection(attended_values)


#Q = input @ W_Q
#K = input @ W_K
#V = input @ W_V
#Attention_Weights = softmax(Q @ K^T / sqrt(d_k))
#Output = Attention_Weights @ V

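# Multiplying the two matrices (Q @ K^T) introduces a second-order (quadratic) term into attention, which adds some non-linearity (not sure)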
        
        

## Multi-Head Self-Attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention using one fused Q/K/V projection."""

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.1, bias: bool = True):
        """
        Build the projections and dropout used by the attention module.

        Args:
            hidden_size (int): Width of the input and output features.
            num_heads (int): How many heads the hidden dimension is split into.
            dropout (float): Dropout probability applied to the attention weights. Default is 0.1.
            bias (bool): Whether the linear layers carry a bias term. Default is True.
        """
        super().__init__()
        assert hidden_size % num_heads == 0, "Hidden size must be divisible by the number of heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Each head works on an equal slice of the hidden dimension
        self.head_dim = hidden_size // num_heads

        self.dropout = nn.Dropout(dropout)

        # One matmul yields queries, keys and values side by side
        self.qkv_linear = nn.Linear(hidden_size, hidden_size * 3, bias=bias)
        # Mixes the concatenated head outputs back into hidden_size features
        self.output_linear = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run self-attention over the sequence dimension of ``x``.

        Args:
            x (Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor: Output of shape (batch_size, seq_len, hidden_size).
        """
        n_batch, n_tokens, _ = x.size()

        # (batch, seq, 3 * hidden) -> (3, batch, heads, seq, head_dim), then split
        fused = self.qkv_linear(x).reshape(n_batch, n_tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled similarity of every query with every key: (batch, heads, seq, seq)
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        probs = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted values, then fold the heads back into the feature dimension
        context = (probs @ values).transpose(1, 2)  # (batch, seq, heads, head_dim)
        context = context.reshape(n_batch, n_tokens, self.hidden_size)

        return self.output_linear(context)

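# Dropout (default probability 10 percent) is commonly applied to both the attention weights and the final output projection. MultiHeadAttention above applies it only to the attention weights; BidirectionalAttention below applies it to both.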

## Bidirectional  Attention

In [4]:
class BidirectionalAttention(nn.Module):
    """Multi-head self-attention in which every position may attend to every other.

    Args:
        hidden_size (int): Size of the input and output features.
        num_heads (int): Number of attention heads; must divide ``hidden_size``.
        dropout (float): Dropout probability used for both the attention
            probabilities and the projected output. Default is 0.1.
        bias (bool): Whether the linear layers use a bias term. Default is True.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.1, bias: bool = True):
        super().__init__()
        assert hidden_size % num_heads == 0, "Hidden size must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor:
        """Apply bidirectional self-attention.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length, hidden_size).
            mask (BoolTensor, optional): Boolean tensor whose True entries mark
                key positions to exclude. It is unsqueezed twice at dim 1, so a
                (batch_size, seq_length) mask broadcasts over heads and queries.

        Returns:
            Tensor: Output of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, hidden_size = x.size()

        # Split the feature dimension into heads: (batch, seq, heads, head_dim)
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)

        # Scaled dot product of each query with each key: (batch, heads, q, k)
        attn_scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.head_dim)

        if mask is not None:
            # -inf scores receive zero probability after the softmax
            attn_scores = attn_scores.masked_fill(mask.unsqueeze(1).unsqueeze(1), float("-inf"))

        attn_probs = self.attn_dropout(F.softmax(attn_scores, dim=-1))

        # Contract over the key axis `k`, which must be shared by the
        # probabilities and the values. Labelling the two axes differently
        # sums the probabilities on their own and returns v unchanged.
        attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_probs, v)
        attn_output = attn_output.reshape(batch_size, seq_length, hidden_size)

        return self.proj_dropout(self.out_proj(attn_output))
        
    

In [5]:
class CausalAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions.

    Args:
        hidden_size (int): Size of the input and output features.
        num_heads (int): Number of attention heads; must divide ``hidden_size``.
        context_size (int): Side length of the precomputed causal mask. The mask
            is sliced to the sequence length in ``forward``.
        dropout (float): Dropout probability for the attention probabilities.
            Default is 0.1.
        bias (bool): Whether the linear layers use a bias term. Default is True.
    """

    def __init__(self, hidden_size: int, num_heads: int, context_size: int, dropout: float = 0.1, bias: bool = True):
        super().__init__()
        assert hidden_size % num_heads == 0, "Hidden size must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular boolean matrix: True where key index > query
        # index, i.e. the future positions that must be hidden.
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_size, context_size, dtype=torch.bool), diagonal=1))

    def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor:
        """Apply causal self-attention.

        Args:
            x (Tensor): Input of shape (batch_size, seq_length, hidden_size).
            mask (BoolTensor, optional): Boolean tensor whose True entries mark
                key positions to exclude in addition to the causal mask. It is
                unsqueezed twice at dim 1, so a (batch_size, seq_length) mask
                broadcasts over heads and queries.

        Returns:
            Tensor: Output of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, hidden_size = x.size()

        # Split the feature dimension into heads: (batch, seq, heads, head_dim)
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)

        # Scaled dot product of each query with each key: (batch, heads, q, k)
        attn_scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.head_dim)

        # (1, 1, seq, seq) so it broadcasts over batch and heads
        causal_mask = self.causal_mask[:seq_length, :seq_length].unsqueeze(0).unsqueeze(1)

        if mask is not None:
            # A key is hidden if it lies in the future OR the caller masked it
            attn_scores = attn_scores.masked_fill(causal_mask | mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        else:
            attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))

        attn_probs = self.dropout(F.softmax(attn_scores, dim=-1))

        # Contract over the key axis `k`, which must be shared by the
        # probabilities and the values. Labelling the two axes differently
        # sums the probabilities on their own and returns v unchanged.
        attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_probs, v)
        attn_output = attn_output.reshape(batch_size, seq_length, hidden_size)

        return self.out_proj(attn_output)

In [6]:
class CausalCrossAttention(nn.Module):
    """Multi-head cross-attention with a causal mask.

    Queries come from ``x``; keys and values come from ``y`` through one fused
    projection. Query position ``i`` may attend to key positions ``0..i``.

    Args:
        hidden_size (int): Size of the input and output features.
        num_heads (int): Number of attention heads; must divide ``hidden_size``.
        context_size (int): Side length of the precomputed causal mask. The mask
            is sliced to the query and key lengths in ``forward``.
        dropout (float): Dropout probability for the attention probabilities.
            Default is 0.1.
        bias (bool): Whether the linear layers use a bias term. Default is True.
    """

    def __init__(self, hidden_size: int, num_heads: int, context_size: int, dropout: float = 0.1, bias: bool = True):
        super().__init__()
        assert hidden_size % num_heads == 0, "Hidden size must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        # Keys and values are produced together: first half keys, second half values
        self.kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular boolean matrix: True where key index > query index
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_size, context_size, dtype=torch.bool), diagonal=1))

    def forward(self, x: Tensor, y: Tensor, mask: BoolTensor = None) -> Tensor:
        """Attend from ``x`` (queries) to ``y`` (keys and values).

        Args:
            x (Tensor): Query source of shape (batch_size, seq_length, hidden_size).
            y (Tensor): Key/value source of shape (batch_size, kv_length, hidden_size).
            mask (BoolTensor, optional): Boolean tensor whose True entries mark
                key positions to exclude in addition to the causal mask. It is
                unsqueezed twice at dim 1, so a (batch_size, kv_length) mask
                broadcasts over heads and queries.

        Returns:
            Tensor: Output of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, hidden_size = x.size()
        # Keys/values take their own length from y rather than borrowing x's
        kv_length = y.size(1)

        # (batch, heads, q, head_dim)
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # (batch, kv, 2, heads, head_dim) -> two tensors of (batch, heads, kv, head_dim),
        # giving k and v the same heads-before-sequence layout as q
        kv = self.kv_proj(y).view(batch_size, kv_length, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Scaled dot product of each query with each key: (batch, heads, q, k)
        attn_scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(self.head_dim)

        # (1, 1, q, k) so it broadcasts over batch and heads
        causal_mask = self.causal_mask[:seq_length, :kv_length].unsqueeze(0).unsqueeze(1)

        if mask is not None:
            # A key is hidden if the causal mask hides it OR the caller masked it
            attn_scores = attn_scores.masked_fill(causal_mask | mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        else:
            attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))

        attn_probs = self.dropout(F.softmax(attn_scores, dim=-1))

        # Contract over the shared key axis `k` to mix the value vectors
        attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_probs, v)
        # (batch, heads, q, head_dim) -> (batch, q, hidden)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, hidden_size)

        return self.out_proj(attn_output)

## Feed Forward Network

two takeaways: 1. polysemantic; 2. superposition hypothesis

FFN(X) = Act(X W1) W2

The FFN operates on each token independently of all other tokens in the sequence.

It cannot reference other tokens or positional information outside of the information embedded in the current token vector.

X --> high dimension --> original dimension

In [7]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: expand, apply GELU, contract, dropout.

    Args:
        hidden_size (int): Width of the input and output features.
        expand_size (int): Width of the intermediate (expanded) representation.
        dropout (float): Dropout probability applied to the output. Default is 0.1.
        bias (bool): Whether the linear layers carry a bias term. Default is True.
    """

    def __init__(self, hidden_size: int, expand_size: int, dropout: float = 0.1, bias: bool = True):
        super().__init__()
        # hidden_size -> expand_size
        self.input_projection = nn.Linear(hidden_size, expand_size, bias=bias)
        # Non-linearity between the two projections
        self.activation = nn.GELU()
        # expand_size -> hidden_size
        self.output_projection = nn.Linear(expand_size, hidden_size, bias=bias)
        # Regularisation on the block output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Transform ``x`` along its last dimension and return a tensor of the same shape."""
        expanded = self.activation(self.input_projection(x))
        contracted = self.output_projection(expanded)
        return self.dropout(contracted)

## Transformer Block

Post-Norm can suffer from gradient vanishing as normalization is applied to the outputs of initial layers multiple times.

This can cause the gradient norm to become exponentially small, which hinders model training. Using small learning rates and learning-rate warmup improves Post-Norm training.

Pre-Norm applies the normalization layer to the input before it is passed to the Attention and Feed Forward layers.

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each behind a residual.

    Args:
        hidden_size (int): Width of the token features.
        num_heads (int): Number of attention heads handed to the attention layer.
        context_size (int): Context length handed to the attention layer.
        expand_size (int): Inner width of the feed-forward layer.
        attention: Callable invoked with the keyword arguments ``hidden_size``,
            ``num_heads``, ``context_size``, ``dropout`` and ``bias`` to build
            the attention layer. Defaults to ``CausalAttention``.
        dropout (float): Dropout probability forwarded to both sub-layers.
            Default is 0.1.
        bias (bool): Bias flag forwarded to both sub-layers. Default is True.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        context_size: int,
        expand_size: int,
        attention: nn.Module = CausalAttention,
        dropout: float = 0.1,
        bias: bool = True,
    ):
        super().__init__()
        # Settings both sub-layers have in common
        shared = dict(hidden_size=hidden_size, dropout=dropout, bias=bias)

        # Normalisation applied to the attention input (pre-norm)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn = attention(num_heads=num_heads, context_size=context_size, **shared)

        # Normalisation applied to the feed-forward input (pre-norm)
        self.ffn_norm = nn.LayerNorm(hidden_size)
        self.ffn = FeedForward(expand_size=expand_size, **shared)

    def forward(self, input: Tensor) -> Tensor:
        """Run the block on ``input``; no mask is passed to the attention layer."""
        # Attention sub-layer: normalise, attend, add back the untouched input
        after_attn = input + self.attn(self.attn_norm(input))

        # Feed-forward sub-layer, wired the same way on top of the attention result
        return after_attn + self.ffn(self.ffn_norm(after_attn))