nn.Linear 내부에서 일어나는 일

1. 입력 flatten
- F.linear는 마지막 차원(in_dim)에만 작용하므로, 개념적으로 입력을 view(-1, in_dim) 으로 펼친 것과 같음
- 즉 x (B, I, J, in_dim)를 (B·I·J, in_dim) 행렬로 본 뒤 연산

2. 행렬 곱
- y_flat = x_flat @ Wᵀ + b
- W : (dim*3, in_dim)
- b : (dim*3,)

3. 원래 shape 복원
- 계산된 (B·I·J, dim*3)을 .view(B, I, J, dim*3) 로 되돌림


PyTorch Attention
- PyTorch 2.0 이상의 F.scaled_dot_product_attention 에는 FlashAttention 백엔드가 내장되어 있음

In [1]:
import torch
import torch.nn.functional as F

In [None]:
from packaging import version
from functools import partial

# Parsed PyTorch version (e.g. "2.5.1+cu124"), used to pick which SDPA API to call.
torch_version = version.parse(torch.__version__)
print(torch_version)

# Scaled-dot-product-attention kernels that may be used (True = allowed).
sdp_kwargs = {
    "enable_flash": True,
    "enable_math": True,
    "enable_mem_efficient": True,
}

if torch_version < version.parse('2.3'):
    # Older API (deprecated): takes the enable_* flags as keyword arguments.
    sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
else:
    from torch.nn.attention import SDPBackend

    # Map each enable_* flag name to its SDPBackend member.
    str_to_backend = {
        "enable_flash": SDPBackend.FLASH_ATTENTION,
        "enable_mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
        "enable_math": SDPBackend.MATH,
        "enable_cudnn": SDPBackend.CUDNN_ATTENTION,
    }

    # Newer API: several backends can be enabled at once by passing a list.
    sdpa_backends = [
        str_to_backend[flag_name]
        for flag_name, is_enabled in sdp_kwargs.items()
        if is_enabled
    ]
    sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)

2.5.1+cu124


In [None]:
# Toy attention inputs laid out as (B, H, T, C/H).
# Use CUDA + float16 when a GPU is present (FlashAttention needs fp16/bf16);
# otherwise fall back to CPU + float32 so the cell still runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)
k = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)
v = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)
mask = None  # no explicit attention mask

In [10]:
%%timeit
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

48.9 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
# Inspect the attention output shape; it should match q's (B, H, T, C/H).
out.size()

torch.Size([2, 8, 16, 64])

In [11]:
# Rebind the global `sdp_context_manager` to the legacy
# torch.backends.cuda.sdp_kernel API, with the same enable_* flags, so the next
# cell can time it for comparison.
# NOTE: this overwrites the global, so any cell run afterwards also uses the
# legacy API.
sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)


In [12]:
%%timeit
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

  self.gen = func(*args, **kwds)


61.3 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with three interchangeable kernels.

    Args:
        cfg: config object. Reads ``n_emb``, ``n_head`` (or ``n_heads``),
            ``bias``, ``dropout`` and ``block_size``. It may also set ``flash``
            (default True) and ``einops``/``manual_attn`` to pick the
            non-fused implementation.
        device: device on which the causal-mask buffer is allocated.

    The fused ``F.scaled_dot_product_attention`` path is used when PyTorch
    provides it and ``cfg.flash`` is not False. Otherwise the einsum path runs
    if ``cfg.einops`` is set, and the explicit matmul path runs in all other
    cases.
    """

    def __init__(self, cfg, device):
        super().__init__()
        self.cfg = cfg

        # Accept either spelling of the head-count attribute and resolve it once,
        # so __init__ and forward always agree.
        n_head = getattr(cfg, "n_head", None)
        if n_head is None:
            n_head = cfg.n_heads
        assert cfg.n_emb % n_head == 0, "n_emb must be divisible by n_heads"
        self.n_head = n_head
        self.head_dim = cfg.n_emb // n_head

        # Q, K and V projections computed by one Linear layer.
        self.c_attn = nn.Linear(cfg.n_emb, 3 * cfg.n_emb, bias=cfg.bias)

        # Output projection.
        self.c_proj = nn.Linear(cfg.n_emb, cfg.n_emb, bias=cfg.bias)

        # Regularization.
        self.attn_dropout = nn.Dropout(cfg.dropout)
        self.residual_dropout = nn.Dropout(cfg.dropout)

        # Fused SDPA only if this PyTorch has it and the config does not opt out.
        self.flash = bool(
            hasattr(torch.nn.functional, "scaled_dot_product_attention")
            and getattr(cfg, "flash", True)
        )
        if not self.flash:
            # Lower-triangular causal mask, used only by the non-fused paths.
            # Registered whenever those paths can run, so forward never lacks it.
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(1, 1, cfg.block_size, cfg.block_size, device=device)),
            )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: (batch, length, n_emb); length must not exceed ``cfg.block_size``
                on the non-fused paths.

        Returns:
            Tensor of shape (batch, length, n_emb).
        """
        B, T, C = x.shape  # (batch, length, n_emb)
        # (B, T, n_emb) -> (B, T, 3 * n_emb), then split into q, k, v.
        q, k, v = self.c_attn(x).split(self.cfg.n_emb, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.cfg.dropout if self.training else 0,
                is_causal=True
            )
        elif getattr(self.cfg, "einops", False):
            # QK^T via einsum, scaled by 1/sqrt(head_dim).
            attn_weights = torch.einsum("b h t d, b h s d -> b h t s", q, k) * (self.head_dim ** -0.5)

            # Causal mask: block attention to future positions.
            attn_weights = attn_weights.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            # Softmax in float32 for numerical stability, then cast back.
            attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
            attn_weights = self.attn_dropout(attn_weights)

            # Weighted sum of V.
            y = torch.einsum("b h t s, b h s d -> b h t d", attn_weights, v)
        else:
            # Explicit matmul implementation (also the fallback when no flag is set).
            # (B, H, T, C/H) x (B, H, C/H, T) -> (B, H, T, T)
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # True above the diagonal, i.e. at future positions.
            mask = self.bias[:, :, :T, :T] == 0
            # -inf there so that softmax assigns those positions zero probability.
            att = att.masked_fill(mask, float('-inf'))
            att = F.softmax(att, dim=-1)  # still (B, H, T, T)
            att = self.attn_dropout(att)
            y = att @ v  # (B, H, T, T) x (B, H, T, C/H) -> (B, H, T, C/H)

        # (B, H, T, C/H) -> (B, T, H, C/H) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection.
        y = self.residual_dropout(self.c_proj(y))
        return y

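Linear attention

- softmax 대신 q, k에 kernel(feature map) 함수 φ를 먼저 적용
- 그 뒤 결합 법칙으로 φ(Q)(φ(K)ᵀV) 순서로 계산하면 N×N 어텐션 행렬을 만들 필요가 없어, 시퀀스 길이 N에 대해 O(N²) → O(N)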

![Image](https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2Fcqm9dE%2FbtsshkDxDp5%2F0bouLgFTM4jwSnSr2Ngq40%2Fimg.png)

In [None]:
def rotate_half(x):
    

class LlamaAttention(nn.Module):
    def __init__(self, cfg, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.n_heads)
        self.scailing = self.head_dim ** -0.5
        self.attention_dropout = cfg.attention_dropout

        # Causal masking
        # 
        self.is_causal = True

        self.q_proj = nn.Linear(
            cfg.hidden_size, cfg.num_attention_heads * self.head_dim, bias=cfg.attention_bias
        )

        # K, V는 Grouped-Query Attention(GQA)
        self.k_proj = nn.Linear(
            cfg.hidden_size, cfg.num_key_value_heads * self.head_dim, bias=cfg.attention_bias
        )
        self.v_proj = nn.Linear(
            cfg.hidden_size, cfg.num_key_value_heads * self.head_dim, bias=cfg.attention_bias
        )

        self.o_proj = nn.Linear(
            cfg.num_key_value_heads * self.head_dim, cfg.hidden_size, bias=cfg.attention_bias
        )

    def forward(
            self, 
            hidden_states: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        
    ):
        # 이렇게 한 이유는 K, V의 num_heads가 달라서 -1로 자동으로 처리하게 하려고
        input_shape = hidden_states.shape[:-1] # [B, T]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # [B, T, C] -> [B, T, H, C/H] -> [B, H, T, C/H]
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)



        

Axial Attention

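- 한 번에 한 축(axis)씩만 어텐션을 계산
- 축마다 시퀀스 하나의 비용 기준으로 복잡도를 $O\left(\sum_{a \in \text{axes}} L_a^2\right)$ 로 줄임. 전체 입력 기준으로는 $O\left(N \sum_{a} L_a\right)$ ($N = \prod_a L_a$)로, full attention의 $O(N^2)$ 보다 훨씬 작음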


언제 유용할까?

- 이미지 패치 (H × W)
- 비디오 (T, H, W): 시간 축과 공간 축을 번갈아 가며 어텐션
- Protein/RNA pair matrix

In [None]:
import torch
import torch.nn as nn


class AxialAttention2D(nn.Module):
    """Axial self-attention over a 2-D grid: attend along H, then along W.

    Each pass is a residual ``nn.MultiheadAttention`` followed by LayerNorm.
    The same ``self.ln`` module is applied after both passes.

    Args:
        dim: channel size C of the input (embedding dim of the attention).
        heads: number of attention heads.
        dropout: dropout used inside attention and on the residual branch.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.dim = dim
        self.hattn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.wattn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        x : (B, H, W, C)
        returns : (B, H, W, C)
        """
        b, h, w, c = x.shape
        out = x

        # (1) Height-axis attention: every column (fixed b, w) is a sequence of length H.
        #     (B, H, W, C) -> (B, W, H, C) -> (B*W, H, C)
        seq_h = out.permute(0, 2, 1, 3).reshape(b * w, h, c)
        # Attention weights are not used, so skip computing them.
        ah, _ = self.hattn(seq_h, seq_h, seq_h, need_weights=False)
        ah = ah.reshape(b, w, h, c).permute(0, 2, 1, 3)  # back to (B, H, W, C)
        out = self.ln(out + self.dropout(ah))

        # (2) Width-axis attention: every row (fixed b, h) is a sequence of length W.
        #     (B, H, W, C) -> (B*H, W, C)
        seq_w = out.reshape(b * h, w, c)
        aw, _ = self.wattn(seq_w, seq_w, seq_w, need_weights=False)
        aw = aw.reshape(b, h, w, c)
        out = self.ln(out + self.dropout(aw))

        return out
        

In [None]:
class LlamaAttention(nn.Module):
    def __init__(self, config, layer_idx):
        """Set up Q/K/V/O projections for Llama attention with grouped-query attention.

        Args:
            config: model config. Reads hidden_size, num_attention_heads,
                num_key_value_heads, max_position_embeddings, rope_theta and
                attention_bias.
            layer_idx: index of this layer (stored on the module).

        Raises:
            ValueError: if hidden_size is not divisible by num_attention_heads,
                or if num_attention_heads is not divisible by num_key_value_heads.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each K/V head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads "
                f"(got hidden_size={self.hidden_size}, num_heads={self.num_heads})"
            )
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads must be divisible by num_key_value_heads "
                f"(got num_heads={self.num_heads}, "
                f"num_key_value_heads={self.num_key_value_heads})"
            )

        # Q has num_heads heads; K/V have num_key_value_heads heads (GQA).
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self._init_rope()
        
    def _init_rope(self):
        """Placeholder for rotary-position-embedding setup (not implemented yet).

        For now it only stores ``Ellipsis`` in ``self.rotary_emb`` as a stand-in.
        """
        self.rotary_emb = Ellipsis

    def forward(self, hidden_states : torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()
        