In [1]:
import math
import torch
from torch import nn

In [2]:
from dataclasses import dataclass


@dataclass
class LlamaConfig:
    """Hyper-parameters for the LLaMA-2 components built in this notebook.

    Field order is part of the interface: configs are also constructed positionally.

    Raises:
        ValueError: if ``attention_heads`` is not positive or does not divide ``hidden_size``.
    """
    vocab_size: int               # number of token ids
    hidden_size: int              # embedding dimension per token
    intermediate_size: int        # inner (expanded) dimension of the MLP / FFN
    num_hidden_layers: int        # number of decoder layers
    attention_heads: int          # number of query heads
    key_value_heads: int          # number of key/value heads
    hidden_act: str               # FFN activation function name
    max_position_embeddings: int  # maximum sequence length

    def __post_init__(self):
        # Fail fast with a clear message instead of a shape error deep inside attention,
        # where hidden_size is split evenly across the heads.
        if self.attention_heads <= 0 or self.hidden_size % self.attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by a positive "
                f"attention_heads ({self.attention_heads})"
            )

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019), as used in LLaMA.

    ``y = gamma * x / sqrt(mean(x**2 over the last dim) + eps)``

    Unlike LayerNorm there is no mean-centering and no bias.

    Args:
        hidden_size: size of the last dimension of the input; length of ``gamma``.
        eps: added to the mean square *inside* the square root for numerical stability.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))  # learnable per-channel scale
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eps belongs inside the sqrt (LLaMA reference, torch.nn.RMSNorm); adding it
        # after the sqrt is a different, non-standard normalization.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * inv_rms)

In [4]:
class RoPE(nn.Module):
    """Rotary position embedding (Su et al., RoFormer) using interleaved feature pairs.

    The feature pair ``(x[2i], x[2i+1])`` at position ``m`` is rotated by the angle
    ``m * theta_i`` with ``theta_i = base ** (-2i / hidden_size)``. Positions are
    ``0 .. seq_len - 1`` along dimension -2.

    Args:
        hidden_size: size of the last (rotated) dimension; must be even.
        base: frequency base (10000 in RoFormer / LLaMA).

    Raises:
        ValueError: if ``hidden_size`` is odd.
    """

    def __init__(self, hidden_size: int, base: float = 10000.0):
        super().__init__()
        if hidden_size % 2 != 0:
            raise ValueError(f"RoPE needs an even hidden_size, got {hidden_size}")
        pair_idx = torch.arange(0, hidden_size // 2, dtype=torch.float32)
        # base^(-2i/d) computed as exp((-2i/d) * log(base))
        theta = torch.exp(pair_idx * (-2.0 * math.log(base) / hidden_size))
        # A buffer (not a Parameter): it follows .to(device) but is not trained.
        self.register_buffer('theta', theta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` of shape (..., seq_len, hidden_size); returns the same shape."""
        seq_len = x.size(-2)
        # Create positions on theta's device; a CPU arange breaks once the module is on GPU.
        positions = torch.arange(seq_len, device=self.theta.device, dtype=self.theta.dtype)
        angles = torch.outer(positions, self.theta)     # (seq_len, hidden_size // 2)
        angles = angles.repeat_interleave(2, dim=-1)    # (seq_len, hidden_size): t0, t0, t1, t1, ...

        # Pairwise 90-degree rotation: (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...).
        # `...` indexing lets this work for any leading rank, e.g. (n, heads, seq, dim).
        x_rot90 = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1).flatten(start_dim=-2)
        return x * torch.cos(angles) + x_rot90 * torch.sin(angles)

In [None]:
class RoPESelfAttention(nn.Module):
    """Pre-norm multi-head self-attention with rotary position embeddings (LLaMA style).

    ``out = x + W_o(Attention(RoPE(W_q n), RoPE(W_k n), W_v n))`` with ``n = RMSNorm(x)``.
    The residual connection is included, so the output has the same shape as the
    input, (batch, seq_len, hidden_size).

    NOTE(review): ``config.key_value_heads`` (grouped-query attention) is not used
    yet; every head has its own K/V projection.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``attention_heads``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()

        if config.hidden_size % config.attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"attention_heads ({config.attention_heads})"
            )
        self.num_attention_heads = config.attention_heads
        self.head_size = config.hidden_size // self.num_attention_heads

        # LLaMA uses bias-free projections for Q, K, V and the output.
        self.query = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        # W_o: mixes the concatenated heads back into a single hidden vector.
        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        # One pre-attention norm here; the norm before the FFN belongs to the decoder layer.
        self.rms_norm = RMSNorm(config.hidden_size)
        # RoPE rotates each head independently, so its frequencies are built for head_size.
        self.rope = RoPE(self.head_size)

    def to_multi_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(n, seq_len, hidden_size) -> (n, num_heads, seq_len, head_size)."""
        n, seq_len, _ = x.size()
        mh = x.view(n, seq_len, self.num_attention_heads, self.head_size)
        return mh.permute(0, 2, 1, 3)

    def _apply_rope(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RoPE per head to a (n, num_heads, seq_len, head_size) tensor."""
        n, num_heads, seq_len, head_size = x.size()
        # Fold heads into the batch dim so RoPE sees a 3-D (batch, seq_len, dim) tensor.
        rotated = self.rope(x.reshape(n * num_heads, seq_len, head_size))
        return rotated.view(n, num_heads, seq_len, head_size)

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        """Attend over ``x`` of shape (n, seq_len, hidden_size).

        Args:
            x: input hidden states.
            is_causal: if True (decoder default), position i only attends to positions <= i.

        Returns:
            Tensor of the same shape as ``x`` (attention output plus residual).
        """
        n, seq_len, hidden_size = x.size()
        x_norm = self.rms_norm(x)

        # q, k, v: (n, num_heads, seq_len, head_size)
        q = self._apply_rope(self.to_multi_heads(self.query(x_norm)))
        k = self._apply_rope(self.to_multi_heads(self.key(x_norm)))
        v = self.to_multi_heads(self.value(x_norm))

        # Scale by 1/sqrt(head_size) so logits stay O(1) and softmax does not saturate.
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)

        if is_causal:
            # True above the diagonal = future positions that must not be attended to.
            future = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
            )
            attn_scores = attn_scores.masked_fill(future, float("-inf"))

        attn_probs = torch.softmax(attn_scores, dim=-1)
        # (LLaMA uses no attention dropout, so none is applied here.)

        context = torch.matmul(attn_probs, v)  # (n, num_heads, seq_len, head_size)
        # permute() returns a strided view over the same storage; view() requires
        # contiguous memory, hence the .contiguous() copy before merging the heads.
        context = context.permute(0, 2, 1, 3).contiguous().view(n, seq_len, hidden_size)
        return x + self.output(context)
        

In [None]:
class SwiGLU(nn.Module):
    """SwiGLU gating activation (Shazeer, 2020, "GLU Variants Improve Transformer").

    Parameter-free: the last dimension of the input is split into two equal halves
    ``(value, gate)`` and the result is ``value * SiLU(gate)``, so the last dimension
    is halved. The surrounding FFN supplies the projections, e.g.
    ``down(SwiGLU()(up_and_gate(x)))`` where ``up_and_gate`` maps hidden -> 2 * intermediate.

    Raises:
        ValueError: if the last dimension of the input is odd.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.size(-1) % 2 != 0:
            raise ValueError(f"SwiGLU needs an even last dimension, got {x.size(-1)}")
        value, gate = x.chunk(2, dim=-1)
        # SiLU(g) = g * sigmoid(g)  (Swish with beta = 1)
        return value * nn.functional.silu(gate)

In [None]:
class Llama2Decoder(nn.Module):
    """One LLaMA-2 decoder layer -- work in progress.

    NOTE(review): unfinished skeleton. ``forward`` computes the attention output and
    an RMS-normalised copy of it, but has no ``return`` statement, so it currently
    returns ``None``. The feed-forward steps below are still TODOs.
    """
    def __init__(self, config: LlamaConfig):
        super(Llama2Decoder, self).__init__()
        # Attention sub-layer; it normalises its own input and adds its own residual.
        self.attn = RoPESelfAttention(config)
        # Norm applied to the attention output before the feed-forward sub-layer.
        self.rms_norm = RMSNorm(config.hidden_size)
        # NOTE(review): constructed without arguments, so no hidden/intermediate sizes
        # reach the FFN here -- the FFN projections presumably still need to be defined.
        self.swiglu = SwiGLU()
    
    def forward(self, x: torch.Tensor):
        # context already contains the residual (x + attention output).
        context = self.attn(x)
        context_norm = self.rms_norm(context)


        # TODO: swiglu (gated activation on a projection of context_norm)

        # TODO: ff (down-projection back to hidden_size)

        # TODO: add residual -- NOTE(review): presumably `context` (the post-attention
        # stream), not the layer input x; confirm.

        # TODO: return output
        
    ...

In [None]:
class Llama2Model(nn.Module):
    """Full LLaMA-2 model -- placeholder, not implemented yet.

    TODO: presumably a token embedding, ``num_hidden_layers`` decoder layers, a final
    RMSNorm and an LM head sized from ``LlamaConfig`` -- confirm when implementing.
    """

In [None]:
# Toy input for a quick shape check of the attention block.
SEED = 0
torch.manual_seed(SEED)  # makes the random input (and later weight init) reproducible

batch_size = 32
hidden_size = 256
seq_len = 20
x = torch.randn(batch_size, seq_len, hidden_size)  # (batch, seq_len, hidden_size)
x.shape

In [63]:
# RoPESelfAttention only reads hidden_size and attention_heads; the rest are placeholders.
config = LlamaConfig(
    vocab_size=1,
    hidden_size=hidden_size,
    intermediate_size=1,
    num_hidden_layers=1,
    attention_heads=8,
    key_value_heads=1,
    hidden_act="",
    max_position_embeddings=1,
)
attn = RoPESelfAttention(config)

In [64]:
# Forward pass through the attention block; it adds a residual, so res keeps x's shape.
res = attn(x)

In [66]:
# The attention block adds a residual, so its output must keep the input's shape.
assert res.shape == x.shape, f"unexpected output shape {tuple(res.shape)}"
res.shape

torch.Size([32, 20, 256])