PyTorch Attention
- 최신 버전(PyTorch 2.0 이상)에는 `F.scaled_dot_product_attention`을 통해 flash attention 기능이 내장되어 있음

In [1]:
import torch
import torch.nn.functional as F

In [None]:
from packaging import version
from functools import partial

torch_version = version.parse(torch.__version__)
print(torch_version)

# SDPA backends we allow. Every flag is on, so no backend is ruled out.
sdp_kwargs = {
    "enable_flash": True,
    "enable_math": True,
    "enable_mem_efficient": True,
}

if torch_version < version.parse("2.3"):
    # Older releases: the boolean-flag API (deprecated in newer versions).
    sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
else:
    from torch.nn.attention import SDPBackend

    # Translate each legacy flag name into the matching SDPBackend member.
    str_to_backend = {
        "enable_flash": SDPBackend.FLASH_ATTENTION,
        "enable_mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
        "enable_math": SDPBackend.MATH,
        "enable_cudnn": SDPBackend.CUDNN_ATTENTION,
    }
    # Newer API: takes a list of backends, so several can be enabled at once.
    sdpa_backends = [str_to_backend[flag] for flag in sdp_kwargs if sdp_kwargs[flag]]
    sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)

2.5.1+cu124


In [4]:
# Random fp16 query/key/value tensors on the GPU, each (2, 8, 16, 64), i.e. the
# 4-D (N, H, L, E) layout F.scaled_dot_product_attention accepts.
q, k, v = (
    torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
mask = None  # no explicit attention mask

In [10]:
%%timeit
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

48.9 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
# Assignments made inside a `%%timeit` cell are local to IPython's timing
# function, so `out` from the benchmark cell is not available here; compute it
# explicitly so this cell works when the notebook is run top to bottom.
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)
out.shape

torch.Size([2, 8, 16, 64])

In [11]:
# Flag-name -> SDPBackend lookup table (same entries as in the setup cell).
str_to_backend = {
    "enable_flash": SDPBackend.FLASH_ATTENTION,
    "enable_mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    "enable_math": SDPBackend.MATH,
    "enable_cudnn": SDPBackend.CUDNN_ATTENTION,
}
# Swap in the legacy, deprecated boolean-flag context manager so it can be
# benchmarked against the newer `sdpa_kernel` API.
sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)


In [12]:
%%timeit
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

  self.gen = func(*args, **kwds)


61.3 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
import math

from torch import nn


class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Three attention implementations are available:
      * an ``einsum`` version (``cfg.einops``),
      * an explicit matmul version (``cfg.manual_attn``),
      * PyTorch's fused ``F.scaled_dot_product_attention``, used whenever
        ``cfg.flash`` is true and the running PyTorch provides it.

    Args:
        cfg: config object. This class reads ``n_emb``, ``n_head``, ``bias``,
            ``dropout``, ``flash``, ``block_size``, ``einops`` and ``manual_attn``.
        device: device on which the causal-mask buffer is allocated.
    """

    def __init__(self, cfg, device):
        super().__init__()
        self.cfg = cfg
        # Use `n_head` everywhere (forward reads it too).
        assert cfg.n_emb % cfg.n_head == 0, "n_emb must be divisible by n_heads"
        self.head_dim = cfg.n_emb // cfg.n_head

        # Q, K, V projections fused into a single Linear layer.
        self.c_attn = nn.Linear(cfg.n_emb, 3 * cfg.n_emb, bias=cfg.bias)

        # Output projection.
        self.c_proj = nn.Linear(cfg.n_emb, cfg.n_emb, bias=cfg.bias)

        # Regularization.
        self.attn_dropout = nn.Dropout(cfg.dropout)
        self.residual_dropout = nn.Dropout(cfg.dropout)

        # Fused kernel only when the config asks for it AND PyTorch has it.
        # Deciding both the mask buffer and the forward dispatch from this one
        # flag keeps them in sync.
        self.flash = bool(cfg.flash) and hasattr(
            torch.nn.functional, "scaled_dot_product_attention"
        )
        if not self.flash:
            # Lower-triangular ones: query position t may attend to keys <= t.
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(1, 1, cfg.block_size, cfg.block_size, device=device)),
            )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, seq_len, n_emb). On the non-fused paths
                seq_len must not exceed ``cfg.block_size`` (the mask buffer
                is sliced to seq_len).

        Returns:
            Tensor of shape (batch, seq_len, n_emb).
        """
        B, T, C = x.shape  # (batch, length, n_emb)
        H, D = self.cfg.n_head, self.head_dim

        # (B, T, n_emb) -> (B, T, 3 * n_emb), then split into q, k, v.
        q, k, v = self.c_attn(x).split(self.cfg.n_emb, dim=2)
        # Split channels into heads: (B, T, C) -> (B, H, T, D).
        k = k.view(B, T, H, D).transpose(1, 2)
        q = q.view(B, T, H, D).transpose(1, 2)
        v = v.view(B, T, H, D).transpose(1, 2)

        if not self.flash and self.cfg.einops:
            # Scaled scores QK^T / sqrt(D): (B, H, T, D) x (B, H, T, D) -> (B, H, T, T).
            attn_weights = torch.einsum("b h t d, b h s d -> b h t s", q, k) * (self.head_dim ** -0.5)
            # Causal mask: future positions get -inf.
            attn_weights = attn_weights.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            # Softmax in float32 for numerical stability, then cast back.
            attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
            attn_weights = self.attn_dropout(attn_weights)
            # Weighted sum of values: (B, H, T, T) x (B, H, T, D) -> (B, H, T, D).
            y = torch.einsum("b h t s, b h s d -> b h t d", attn_weights, v)
        elif not self.flash and self.cfg.manual_attn:
            # (B, H, T, D) @ (B, H, D, T) -> (B, H, T, T), scaled by 1/sqrt(D).
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # Boolean mask, True strictly above the diagonal (future positions).
            mask = self.bias[:, :, :T, :T] == 0
            # -inf scores become probability 0 after softmax.
            att = att.masked_fill(mask, float("-inf"))
            att = F.softmax(att, dim=-1)  # still (B, H, T, T)
            att = self.attn_dropout(att)
            y = att @ v  # (B, H, T, T) x (B, H, T, D) -> (B, H, T, D)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.cfg.dropout if self.training else 0,
                is_causal=True,
            )
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, C): merge heads back.
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection.
        y = self.residual_dropout(self.c_proj(y))
        return y

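Linear attention

- q, k에 kernel(feature map) 함수 $\phi$를 먼저 적용한 뒤, 행렬 곱의 순서를 바꿔 $\sum_j \phi(K_j) V_j^\top$를 먼저 계산함 → 시퀀스 길이 $N$에 대해 $O(N^2)$ 대신 $O(N)$ 복잡도:

$$V'_i = \frac{\phi(Q_i)^\top \sum_{j} \phi(K_j)\, V_j^\top}{\phi(Q_i)^\top \sum_{j} \phi(K_j)}$$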

![Image](https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2Fcqm9dE%2FbtsshkDxDp5%2F0bouLgFTM4jwSnSr2Ngq40%2Fimg.png)

In [None]:
LlamaAttention

In [None]:
class LlamaAttention(nn.Module):
    def __init__(self, config, layer_idx):
        """Set up the projections for multi-head attention.

        Key/value projections use ``num_key_value_heads`` heads, which may be
        fewer than the query heads (grouped-query style).

        Args:
            config: model config. Reads ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``max_position_embeddings``,
                ``rope_theta`` and ``attention_bias``.
            layer_idx: index of this layer, stored as ``self.layer_idx``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Use the instance attributes. The bare names were never defined.
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads "
                f"(got hidden_size={self.hidden_size}, num_heads={self.num_heads})"
            )

        # nn.Linear(in_features, out_features, bias=...). Passing num_heads as a
        # third positional argument would collide with the `bias` keyword.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self._init_rope()
        
    def _init_rope(self):
        """Create the rotary position embedding (not implemented yet)."""
        # Placeholder: the attribute currently holds the Ellipsis object.
        self.rotary_emb = Ellipsis

    def forward(self, hidden_states : torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()
        