In [1]:
import math
import torch
from torch import nn

In [2]:
from dataclasses import dataclass


@dataclass
class LlamaConfig:
    vocab_size: int  # (토큰 id 갯수)
    hidden_size: int # (토큰당 임베딩 차원 크기)
    intermediate_size: int # (MLP 막 거친 차원 크기)
    num_hidden_layers: int # (디코딩 레이어 갯수)
    attention_heads: int # Query 헤드
    key_value_heads: int # KV 헤드
    hidden_act: str # (FFN 활성화 함수)
    max_position_embeddings: int #(시퀀스 길이)

In [3]:
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super(RMSNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
    
    def forward(self, x: torch.Tensor):
        rms_x = x.square().mean(dim=-1, keepdim=True).sqrt()
        return self.gamma * x / (rms_x + self.eps)

In [4]:
class RoPE(nn.Module):
    def __init__(self, hidden_size: int):
        super(RoPE, self).__init__()
        d_indices = torch.arange(0, hidden_size / 2, dtype=torch.float32)  # size: hidden_size / 2
        theta = torch.exp(d_indices * -2 * math.log(10000) / hidden_size)  # 10000^(-2i/d) == exp((-2i/d) * log(10000))
        
        self.register_buffer('theta', theta) # device 연동
        
    def forward(self, x: torch.Tensor):  # x: (n, seq_len, hidden_size)
        seq_len = x.size(-2)
        m_theta = torch.matmul(torch.arange(seq_len).float().view(-1, 1), self.theta.view(1, -1))
        m_theta = m_theta.repeat_interleave(2, dim=-1) # (1, 2, 3) -> (1, 1, 2, 2, 3, 3)

        x_flip = torch.stack([-x[:, :, 1::2], x[:, :, ::2], ], dim=-1).flatten(start_dim=-2)  # -x2, x1, -x4, x3 ...
        rotated_x = torch.mul(x, torch.cos(m_theta)) + torch.mul(x_flip, torch.sin(m_theta))
        return rotated_x

In [5]:
batch_size = 32
hidden_size = 256
seq_len = 20
x = torch.randn(batch_size, seq_len, hidden_size)
x.shape

torch.Size([32, 20, 256])

In [None]:
rope = RoPE(hidden_size=hidden_size)

In [None]:
rope(x).shape

torch.Size([32, 20, 256])

In [None]:
m_theta = torch.matmul(torch.arange(seq_len).float().view(-1, 1), rope.theta.view(1, -1))
m_theta = m_theta.repeat_interleave(2, dim=-1)
m_theta.shape

torch.Size([20, 256])

In [277]:
m_theta

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 9.3057e-01,  ..., 1.1548e-04, 1.0746e-04,
         1.0746e-04],
        [2.0000e+00, 2.0000e+00, 1.8611e+00,  ..., 2.3096e-04, 2.1492e-04,
         2.1492e-04],
        ...,
        [1.7000e+01, 1.7000e+01, 1.5820e+01,  ..., 1.9631e-03, 1.8268e-03,
         1.8268e-03],
        [1.8000e+01, 1.8000e+01, 1.6750e+01,  ..., 2.0786e-03, 1.9343e-03,
         1.9343e-03],
        [1.9000e+01, 1.9000e+01, 1.7681e+01,  ..., 2.1941e-03, 2.0418e-03,
         2.0418e-03]])

In [263]:
torch.mul(x, m_theta).shape

torch.Size([32, 20, 256])

In [None]:
torch.matmul(torch.arange(1, 6).float().view(5, 1), torch.ones(5).view(1, 5))

tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.],
        [5., 5., 5., 5., 5.]])

In [271]:
dummy = torch.stack([torch.stack([torch.arange(1, hidden_size + 1)] * seq_len)] * batch_size)
dummy_flip = torch.stack([-dummy[:, :, 1::2], dummy[:, :, ::2], ], dim=-1).flatten(start_dim=-2)

In [278]:
dummy

tensor([[[  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         ...,
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256]],

        [[  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         ...,
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256]],

        [[  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         ...,
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256]],

        ...,

        [[  1,   2,   3,  ..., 254, 255, 256],
         [  1,   2,   3,  ..., 254, 255, 256]

In [279]:
torch.mul(dummy, m_theta)

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 2.0000e+00, 2.7917e+00,  ..., 2.9331e-02,
          2.7402e-02, 2.7510e-02],
         [2.0000e+00, 4.0000e+00, 5.5834e+00,  ..., 5.8663e-02,
          5.4805e-02, 5.5020e-02],
         ...,
         [1.7000e+01, 3.4000e+01, 4.7459e+01,  ..., 4.9863e-01,
          4.6584e-01, 4.6767e-01],
         [1.8000e+01, 3.6000e+01, 5.0251e+01,  ..., 5.2797e-01,
          4.9324e-01, 4.9518e-01],
         [1.9000e+01, 3.8000e+01, 5.3043e+01,  ..., 5.5730e-01,
          5.2065e-01, 5.2269e-01]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 2.0000e+00, 2.7917e+00,  ..., 2.9331e-02,
          2.7402e-02, 2.7510e-02],
         [2.0000e+00, 4.0000e+00, 5.5834e+00,  ..., 5.8663e-02,
          5.4805e-02, 5.5020e-02],
         ...,
         [1.7000e+01, 3.4000e+01, 4.7459e+01,  ..., 4.9863e-01,
          4.658

In [None]:
hidden_size = 256
theta = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float32) * -math.log(10000) / hidden_size)

In [97]:
theta

tensor([1.0000e+00, 9.3057e-01, 8.6596e-01, 8.0584e-01, 7.4989e-01, 6.9783e-01,
        6.4938e-01, 6.0430e-01, 5.6234e-01, 5.2330e-01, 4.8697e-01, 4.5316e-01,
        4.2170e-01, 3.9242e-01, 3.6517e-01, 3.3982e-01, 3.1623e-01, 2.9427e-01,
        2.7384e-01, 2.5483e-01, 2.3714e-01, 2.2067e-01, 2.0535e-01, 1.9110e-01,
        1.7783e-01, 1.6548e-01, 1.5399e-01, 1.4330e-01, 1.3335e-01, 1.2409e-01,
        1.1548e-01, 1.0746e-01, 1.0000e-01, 9.3057e-02, 8.6596e-02, 8.0584e-02,
        7.4989e-02, 6.9783e-02, 6.4938e-02, 6.0430e-02, 5.6234e-02, 5.2330e-02,
        4.8697e-02, 4.5316e-02, 4.2170e-02, 3.9242e-02, 3.6517e-02, 3.3982e-02,
        3.1623e-02, 2.9427e-02, 2.7384e-02, 2.5483e-02, 2.3714e-02, 2.2067e-02,
        2.0535e-02, 1.9110e-02, 1.7783e-02, 1.6548e-02, 1.5399e-02, 1.4330e-02,
        1.3335e-02, 1.2409e-02, 1.1548e-02, 1.0746e-02, 1.0000e-02, 9.3057e-03,
        8.6596e-03, 8.0584e-03, 7.4989e-03, 6.9783e-03, 6.4938e-03, 6.0430e-03,
        5.6234e-03, 5.2330e-03, 4.8697e-

In [93]:
torch.cos(theta.repeat_interleave(2))
torch.sin(theta.repeat_interleave(2))

tensor([8.4147e-01, 8.4147e-01, 8.0196e-01, 8.0196e-01, 7.6172e-01, 7.6172e-01,
        7.2141e-01, 7.2141e-01, 6.8156e-01, 6.8156e-01, 6.4256e-01, 6.4256e-01,
        6.0469e-01, 6.0469e-01, 5.6818e-01, 5.6818e-01, 5.3317e-01, 5.3317e-01,
        4.9974e-01, 4.9974e-01, 4.6795e-01, 4.6795e-01, 4.3781e-01, 4.3781e-01,
        4.0931e-01, 4.0931e-01, 3.8242e-01, 3.8242e-01, 3.5711e-01, 3.5711e-01,
        3.3332e-01, 3.3332e-01, 3.1098e-01, 3.1098e-01, 2.9004e-01, 2.9004e-01,
        2.7043e-01, 2.7043e-01, 2.5208e-01, 2.5208e-01, 2.3492e-01, 2.3492e-01,
        2.1889e-01, 2.1889e-01, 2.0391e-01, 2.0391e-01, 1.8993e-01, 1.8993e-01,
        1.7689e-01, 1.7689e-01, 1.6473e-01, 1.6473e-01, 1.5338e-01, 1.5338e-01,
        1.4281e-01, 1.4281e-01, 1.3296e-01, 1.3296e-01, 1.2378e-01, 1.2378e-01,
        1.1522e-01, 1.1522e-01, 1.0725e-01, 1.0725e-01, 9.9833e-02, 9.9833e-02,
        9.2923e-02, 9.2923e-02, 8.6488e-02, 8.6488e-02, 8.0497e-02, 8.0497e-02,
        7.4919e-02, 7.4919e-02, 6.9726e-

In [85]:
torch.exp(torch.arange(0, hidden_size / 2, dtype=torch.float32) * -2 * math.log(10000) / hidden_size)

tensor([1.0000e+00, 9.3057e-01, 8.6596e-01, 8.0584e-01, 7.4989e-01, 6.9783e-01,
        6.4938e-01, 6.0430e-01, 5.6234e-01, 5.2330e-01, 4.8697e-01, 4.5316e-01,
        4.2170e-01, 3.9242e-01, 3.6517e-01, 3.3982e-01, 3.1623e-01, 2.9427e-01,
        2.7384e-01, 2.5483e-01, 2.3714e-01, 2.2067e-01, 2.0535e-01, 1.9110e-01,
        1.7783e-01, 1.6548e-01, 1.5399e-01, 1.4330e-01, 1.3335e-01, 1.2409e-01,
        1.1548e-01, 1.0746e-01, 1.0000e-01, 9.3057e-02, 8.6596e-02, 8.0584e-02,
        7.4989e-02, 6.9783e-02, 6.4938e-02, 6.0430e-02, 5.6234e-02, 5.2330e-02,
        4.8697e-02, 4.5316e-02, 4.2170e-02, 3.9242e-02, 3.6517e-02, 3.3982e-02,
        3.1623e-02, 2.9427e-02, 2.7384e-02, 2.5483e-02, 2.3714e-02, 2.2067e-02,
        2.0535e-02, 1.9110e-02, 1.7783e-02, 1.6548e-02, 1.5399e-02, 1.4330e-02,
        1.3335e-02, 1.2409e-02, 1.1548e-02, 1.0746e-02, 1.0000e-02, 9.3057e-03,
        8.6596e-03, 8.0584e-03, 7.4989e-03, 6.9783e-03, 6.4938e-03, 6.0430e-03,
        5.6234e-03, 5.2330e-03, 4.8697e-

In [None]:
class RoPESelfAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super(RoPESelfAttention, self).__init__()

        self.num_attention_heads = config.attention_heads
        self.head_size = int(config.hidden_size / self.num_attention_heads)
        
        self.query = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

        self.rms_norm = RMSNorm(config.hidden_size) # 2개?
        self.rope = RoPE(config.hidden_size)
    
    def to_multi_heads(self, x: torch.Tensor):
        n, seq_len, _ = x.size()
        mh = x.view(n, seq_len, self.num_attention_heads, self.head_size)
        return mh.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor):  # x: (n, seq_len, hidden_size)
        x_norm = self.rms_norm(x)
        q = self.rope(self.query(x_norm))
        k = self.rope(self.key(x_norm))

        v = self.value(x_norm)

        q = self.to_multi_heads(q) # q: (n, num_heads, seq_len, hidden_size)
        k = self.to_multi_heads(k)
        v = self.to_multi_heads(v)

        atten_score = torch.matmul(q, k.transpose(-1, -2))
        
        # head_size scaling
        
        # add attention mask
        
        # get probs
        attention_probs = torch.softmax(atten_score, dim=-1)
        
        # dropout

        # matmul value
        context = torch.matmul(attention_probs, v)
        context = context.permute(0, 2, 1, 3).contiguous().view(x.size()) # why contiguous?
        return x + context
        

In [63]:
config = LlamaConfig(1, hidden_size, 1, 1, 8, 1, "", 1)
attn = RoPESelfAttention(config)

In [64]:
res = attn(x)

In [66]:
res.shape

torch.Size([32, 20, 256])

In [52]:
q = nn.Linear(hidden_size, hidden_size)
k = nn.Linear(hidden_size, hidden_size)

q = q(x)
k = k(x)

In [53]:
q_mh = attn.to_multi_heads(q)
k_mh = attn.to_multi_heads(k)
q_mh.shape

torch.Size([32, 8, 20, 32])

In [54]:
k_mh.transpose(-1, -2).shape

torch.Size([32, 8, 32, 20])

In [None]:
attn_score = torch.matmul(q_mh, k_mh.transpose(-1, -2)) 
attn_score.shape # (n, num_heads, seq_len, seq_len) # 단어들의 attn


torch.Size([32, 8, 20, 20])

In [None]:
class Pooler(nn.Module):
    def __init__(self, config: LlamaConfig):
        super(Pooler, self).__init__()
        self.rms_norm = RMSNorm(config.hidden_size)
    
    def forward(self, x: torch.Tensor):
        x_norm = self.rms_norm(x)

        # silu

        # ff

        # add x

        # return output
        

In [7]:
class Llama2Decoder(nn.Module):
    ...

In [None]:
class Llama2Model(nn.Module):
    ...