In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoConfig
import experiment.metrics.metric as metric
from experiment.metrics.metric import cosine_similarity

In [2]:
""" set default device to mps """

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='mps')

In [3]:
""" LoRA 결과 해석 데모 버전
1) 실제는 DeBERTa 사전 학습 버젼, 파인튜닝 버젼 가중치 행렬에 직접 수행
"""

pretrain_w = torch.randn(768, 768, device=device)
lora_a, lora_b = torch.randn(768, 4, device=device), torch.randn(4, 768, device=device)
delta_w = torch.matmul(lora_a, lora_b)

print(f"lora_a: {lora_a.shape}")
print(f"lora_b: {lora_b.shape}")
print(f"delta_w: {delta_w.shape}")

U, S, V = torch.svd(delta_w)
print(f"U: {U.shape}")
print(f"S: {S.shape}")
print(f"V: {V.shape}")

lora_a: torch.Size([768, 4])
lora_b: torch.Size([4, 768])
delta_w: torch.Size([768, 768])
U: torch.Size([768, 768])
S: torch.Size([768])
V: torch.Size([768, 768])


  U, S, V = torch.svd(delta_w)


In [18]:
""" ||U⊤WqV ⊤|| """

r = 4
r_U, r_V = U[:4, :], V[:, :4]
result = torch.matmul(r_U @ pretrain_w, r_V)
f_norm = torch.norm(result)
result, f_norm

(tensor([[-0.2021, -0.0627,  1.4396, -0.0408],
         [-1.3369, -1.4538, -0.3073, -0.4908],
         [ 0.2128, -1.6067, -0.9427, -0.1693],
         [-0.2723, -0.3725,  1.2293, -0.6210]], device='mps:0'),
 tensor(3.4654, device='mps:0'))

In [19]:
""" LoRA 결과 해석 재현 """
pt_config = AutoConfig.from_pretrained('FacebookAI/roberta-base')
pt_model = AutoModel.from_pretrained(
    'FacebookAI/roberta-base',
    config=pt_config
)

lora_checkpoint = torch.load('model/roberta_base_lora_mrpc.bin', map_location='cpu')
lora_checkpoint

Some weights of the model checkpoint at FacebookAI/roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


OrderedDict([('roberta.encoder.layer.0.attention.self.query.lora_A',
              tensor([[ 0.0144, -0.0180,  0.0811,  ..., -0.0082,  0.0915, -0.0301],
                      [-0.0231, -0.0129, -0.0659,  ...,  0.0245,  0.0474, -0.0645],
                      [-0.0081,  0.0331,  0.0666,  ..., -0.0378,  0.0596, -0.0176],
                      ...,
                      [-0.0296, -0.0362,  0.0024,  ...,  0.0099, -0.0136, -0.0613],
                      [-0.0010, -0.0004,  0.0037,  ..., -0.0238,  0.0650,  0.0728],
                      [ 0.0102, -0.0249, -0.0064,  ...,  0.0540, -0.0471, -0.0616]])),
             ('roberta.encoder.layer.0.attention.self.query.lora_B',
              tensor([[-1.0878e-02, -1.1407e-01,  1.9040e-02,  ..., -5.1958e-03,
                        4.4807e-02, -6.5387e-02],
                      [-1.3314e-02, -3.8988e-02,  1.8823e-02,  ...,  2.2530e-02,
                        1.7791e-02, -1.5717e-02],
                      [ 1.5858e-02,  3.6482e-02, -2.4141e-02,  ...

In [20]:
""" Select Wq in 6-th encoder layer """

pt_wq, lora_a, lora_b = pt_model.encoder.layer[6].attention.self.query.weight, lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_A'], lora_checkpoint['roberta.encoder.layer.6.attention.self.query.lora_B']
delta_wq = lora_b @ lora_a
pt_wq.shape, lora_a.shape, lora_b.shape, delta_wq.shape

(torch.Size([768, 768]),
 torch.Size([8, 768]),
 torch.Size([768, 8]),
 torch.Size([768, 768]))

In [21]:
""" Let's SVD, select top-r singular vector """

U, S, V = torch.svd(delta_wq)
print(f"Delta W U: {U.shape}")
print(f"Delta W S: {S.shape}")
print(f"Delta W V: {V.shape}")

r = 4
r_U, r_V = U[:, :r], V[:r, :]
result1 = torch.matmul(r_U.T @ pt_wq, r_V.T)
fwq_norm = torch.norm(result1)
result1, fwq_norm

Delta W U: torch.Size([768, 768])
Delta W S: torch.Size([768])
Delta W V: torch.Size([768, 768])


(tensor([[-0.0441,  0.0447,  0.0323,  0.0963],
         [-0.0038, -0.0412, -0.0903, -0.0949],
         [-0.0314,  0.1003, -0.0599,  0.0023],
         [-0.0222, -0.1090,  0.0315,  0.0575]], grad_fn=<MmBackward0>),
 tensor(0.2539, grad_fn=<LinalgVectorNormBackward0>))

In [22]:
""" ∥∆W ∥F"""
fdwq_norm = torch.norm(delta_wq)
fdwq_norm

tensor(5.0820)

In [23]:
""" Final """

fdwq_norm / fwq_norm

tensor(20.0170, grad_fn=<DivBackward0>)

In [5]:
""" Test code for making rotary embedding more computational efficient """

position = torch.arange(0, 10, dtype=torch.long)
pos_embedding = nn.Embedding(10, 4)
position, pos_embedding

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), Embedding(10, 4))

In [7]:
pos_emb = pos_embedding(position)
sin, cos = pos_emb.chunk(2, dim=-1) 
sin, cos

(tensor([[-0.3922,  0.8835],
         [ 0.3392,  0.0947],
         [ 0.3570, -0.6238],
         [-0.1659, -0.4547],
         [-0.8697,  1.9447],
         [-0.0843, -0.2012],
         [ 0.7752, -0.4308],
         [ 2.1914,  0.8687],
         [-0.5303,  0.8186],
         [-1.2912,  1.4390]], grad_fn=<SplitBackward0>),
 tensor([[-0.9374,  0.7161],
         [-0.5622, -0.5999],
         [-0.8200,  0.9056],
         [ 1.4742, -0.0596],
         [-0.2230,  2.0772],
         [ 1.3687,  0.7459],
         [-1.0737, -0.4662],
         [-0.2012, -1.2607],
         [-0.3093, -0.4485],
         [ 0.9633,  1.0118]], grad_fn=<SplitBackward0>))

In [10]:
""" Sine """
sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(pos_emb)
sin_pos, sin_pos.shape

(tensor([[-0.3922, -0.3922,  0.8835,  0.8835],
         [ 0.3392,  0.3392,  0.0947,  0.0947],
         [ 0.3570,  0.3570, -0.6238, -0.6238],
         [-0.1659, -0.1659, -0.4547, -0.4547],
         [-0.8697, -0.8697,  1.9447,  1.9447],
         [-0.0843, -0.0843, -0.2012, -0.2012],
         [ 0.7752,  0.7752, -0.4308, -0.4308],
         [ 2.1914,  2.1914,  0.8687,  0.8687],
         [-0.5303, -0.5303,  0.8186,  0.8186],
         [-1.2912, -1.2912,  1.4390,  1.4390]], grad_fn=<ReshapeAliasBackward0>),
 torch.Size([10, 4]))

In [11]:
""" Cosine """
cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(pos_emb)
cos_pos, cos_pos.shape

(tensor([[-0.9374, -0.9374,  0.7161,  0.7161],
         [-0.5622, -0.5622, -0.5999, -0.5999],
         [-0.8200, -0.8200,  0.9056,  0.9056],
         [ 1.4742,  1.4742, -0.0596, -0.0596],
         [-0.2230, -0.2230,  2.0772,  2.0772],
         [ 1.3687,  1.3687,  0.7459,  0.7459],
         [-1.0737, -1.0737, -0.4662, -0.4662],
         [-0.2012, -0.2012, -1.2607, -1.2607],
         [-0.3093, -0.3093, -0.4485, -0.4485],
         [ 0.9633,  0.9633,  1.0118,  1.0118]], grad_fn=<ReshapeAliasBackward0>),
 torch.Size([10, 4]))

In [17]:
""" Tensor indexing
... => :,:,:
"""

query = torch.rand(3, 5, 4, 8, device=device)
query.shape, query[:,:,:, 1::2].shape  # ... => 모든 차원 select, 시작 위치::스탭 크기

(torch.Size([3, 5, 4, 8]), torch.Size([3, 5, 4, 4]))

In [19]:
test = torch.rand(16, 12, 512, 64, device=device)
test[:,:,:, ::2].shape, test[..., ::2].shape

(torch.Size([16, 12, 512, 32]), torch.Size([16, 12, 512, 32]))

In [7]:
padding_mask = torch.tensor([
    [0,0,0,1,1],
    [0,0,0,0,1],
    [0,0,1,1,1]
]).to(device)
key = torch.randn(3, 5, 4, device=device)
key[padding_mask == 1] = 0
key

tensor([[[-1.5524,  0.6596,  2.4447,  0.2761],
         [ 0.2821,  0.1620,  0.9388, -0.7997],
         [ 0.7599, -0.9758,  0.1258, -2.2957],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.5441,  0.2044, -1.0543, -0.6610],
         [ 0.2436,  0.3609,  1.6111,  0.0072],
         [-0.4589, -0.5709, -0.9771, -0.9881],
         [-0.3361, -0.6272, -0.0786,  1.5555],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.5229,  0.7038, -0.5714,  2.3070],
         [-0.6425,  0.7309, -1.4076,  0.0123],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [17]:
test = torch.rand(512, 768, 768, requires_grad=False, device=device)
element_size, num_elements = test.element_size(), test.numel()
memory_size = element_size * num_elements
memory_size

1207959552

In [4]:
""" Rotary Position Encoding Test
1) integrate with multi-head attention
2) multiply it before separate the heads
최대한 행렬로 처리할 방법을 찾자
dim 0, 1이 같은 공간으로 들어가고, 
dim 2, 3이 같은 공간으로 들어가게

cos m*theta1, -sin m*theta1 0 0 0 0 0 0 0 0 0 0   여기가 그럼 디
sin m*theta1, cos m*theta1  0 0 0 0 0 0 0 0 0 0
0               0


[Rotary Postion Encoding Algorithm]
1) create a scaler tensor
2) make theta tensor
3) multiple theta tensor and scaler => m*theta1
4) linear projection to transformation matrix
5) multiply it with word embedding
5) linear projectino to query, key
    - value matrix does not multiply with transformation matrix
"""
BATCH_SIZE = 2
MAX_SEQ = 4
DIM_MODEL = 6
scaler = torch.arange(1, MAX_SEQ+1, device=device, dtype=torch.float).unsqueeze(1).repeat(1, DIM_MODEL).reshape(MAX_SEQ, DIM_MODEL).unsqueeze(0).repeat(BATCH_SIZE, 1, 1)
print(f"scaler tensor: {scaler, scaler.shape}")

query = torch.rand(BATCH_SIZE, MAX_SEQ, DIM_MODEL, device=device)
test = torch.mul(query, scaler)
query, test, test.shape

scaler tensor: (tensor([[[1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3.],
         [4., 4., 4., 4., 4., 4.]],

        [[1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3.],
         [4., 4., 4., 4., 4., 4.]]], device='cuda:0'), torch.Size([2, 4, 6]))


(tensor([[[0.8779, 0.3744, 0.3236, 0.7615, 0.3494, 0.2805],
          [0.5583, 0.4323, 0.7704, 0.2998, 0.6243, 0.1674],
          [0.4235, 0.8064, 0.3026, 0.2402, 0.6097, 0.7895],
          [0.3535, 0.9134, 0.1527, 0.5949, 0.6882, 0.5629]],
 
         [[0.5002, 0.6934, 0.1638, 0.4659, 0.5179, 0.3682],
          [0.0644, 0.6148, 0.5767, 0.6943, 0.1547, 0.4829],
          [0.5885, 0.2067, 0.9069, 0.6666, 0.0275, 0.5214],
          [0.7359, 0.3934, 0.3014, 0.2766, 0.2825, 0.0026]]], device='cuda:0'),
 tensor([[[0.8779, 0.3744, 0.3236, 0.7615, 0.3494, 0.2805],
          [1.1165, 0.8646, 1.5407, 0.5996, 1.2486, 0.3347],
          [1.2704, 2.4192, 0.9077, 0.7206, 1.8291, 2.3685],
          [1.4140, 3.6534, 0.6109, 2.3797, 2.7529, 2.2514]],
 
         [[0.5002, 0.6934, 0.1638, 0.4659, 0.5179, 0.3682],
          [0.1288, 1.2296, 1.1534, 1.3886, 0.3094, 0.9658],
          [1.7656, 0.6201, 2.7208, 1.9999, 0.0825, 1.5643],
          [2.9438, 1.5736, 1.2058, 1.1066, 1.1300, 0.0104]]], device='cuda

In [29]:
""" make theta tensor """
i_arr = torch.arange(1, int(DIM_MODEL/2)+1).to(device)
i_arr = i_arr.repeat_interleave(2)
theta = 10000**(-2*(i_arr - 1)/DIM_MODEL)
theta, theta.shape

(tensor([1.0000, 1.0000, 0.0464, 0.0464, 0.0022, 0.0022], device='mps:0'),
 torch.Size([6]))

In [30]:
""" multiple scaler and theta tensor """
var_tensor = torch.mul(scaler, theta)
var_tensor, var_tensor.shape

(tensor([[[1.0000e+00, 1.0000e+00, 4.6416e-02, 4.6416e-02, 2.1544e-03,
           2.1544e-03],
          [2.0000e+00, 2.0000e+00, 9.2832e-02, 9.2832e-02, 4.3089e-03,
           4.3089e-03],
          [3.0000e+00, 3.0000e+00, 1.3925e-01, 1.3925e-01, 6.4633e-03,
           6.4633e-03],
          [4.0000e+00, 4.0000e+00, 1.8566e-01, 1.8566e-01, 8.6177e-03,
           8.6177e-03]],
 
         [[1.0000e+00, 1.0000e+00, 4.6416e-02, 4.6416e-02, 2.1544e-03,
           2.1544e-03],
          [2.0000e+00, 2.0000e+00, 9.2832e-02, 9.2832e-02, 4.3089e-03,
           4.3089e-03],
          [3.0000e+00, 3.0000e+00, 1.3925e-01, 1.3925e-01, 6.4633e-03,
           6.4633e-03],
          [4.0000e+00, 4.0000e+00, 1.8566e-01, 1.8566e-01, 8.6177e-03,
           8.6177e-03]]], device='mps:0'),
 torch.Size([2, 4, 6]))

In [5]:
i_arr = torch.arange(1, int(DIM_MODEL/2)+1).repeat_interleave(2).to(device)
i_arr

tensor([1, 1, 2, 2, 3, 3], device='cuda:0')

In [27]:
BATCH_SIZE = 2
MAX_SEQ = 4
DIM_MODEL = 6
scaler = torch.arange(1, MAX_SEQ+1, device=device, dtype=torch.float).unsqueeze(1).repeat(1, DIM_MODEL).reshape(MAX_SEQ, DIM_MODEL)
print(scaler)

""" make theta tensor """
i_arr = torch.arange(1, int(DIM_MODEL/2)+1).to(device)
i_arr = i_arr.repeat_interleave(2)
theta = 10000**(-2*(i_arr - 1)/DIM_MODEL)
var_tensor = torch.mul(scaler, theta)


def create_rotation_matrix_v2(d, thetas):
    """
    Create a batch of rotation matrices from the given thetas.
    
    Args:
        d (int): The dimensionality of the rotation matrix (must be even).
        thetas (Tensor): A tensor of shape (batch_size, seq_len, d/2) containing the rotation angles.
    
    Returns:
        Tensor: A tensor of shape (batch_size, seq_len, d, d) containing the rotation matrices.
    
    """
    seq_len, _ = thetas.size()
    R = torch.eye(d, device=thetas.device).repeat(seq_len, 1, 1)
    
    for i in range(0, d, 2):
        cos_t = torch.cos(thetas[:, i]).unsqueeze(-1)
        sin_t = torch.sin(thetas[:, i]).unsqueeze(-1)
    
        R[:, i, i] = cos_t.squeeze(-1)
        R[:, i+1, i+1] = cos_t.squeeze(-1)
        R[:, i, i+1] = -sin_t.squeeze(-1)
        R[:, i+1, i] = sin_t.squeeze(-1)
        
    return R
R = create_rotation_matrix_v2(DIM_MODEL, var_tensor)
query = torch.rand(BATCH_SIZE, MAX_SEQ, DIM_MODEL, device=device)
# test = query.clone().detach()
# for s in range(MAX_SEQ):
#     sub_rotary, sub_word = R[s, :, :], test[s, :]
#     test[s, :] = torch.matmul(sub_rotary, sub_word)

query = query.view(BATCH_SIZE, MAX_SEQ, DIM_MODEL)  # 배치 및 시퀀스 차원을 결합하여 행렬 곱셈을 위한 준비
R = R.view(MAX_SEQ, DIM_MODEL, DIM_MODEL)  # 배치 및 시퀀스 차원을 결합하여 행렬 곱셈을 위한 준비

result = torch.vstack([torch.bmm(R, query[i].unsqueeze(-1)).squeeze(-1).view(MAX_SEQ, DIM_MODEL) for i in range(BATCH_SIZE)]).view(BATCH_SIZE, MAX_SEQ, DIM_MODEL)
result, result.shape

tensor([[1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4., 4.]], device='cuda:0')


(tensor([[[-0.1132,  0.1565,  0.3930,  0.6889,  0.9187,  0.7205],
          [-1.1905,  0.5498,  0.9004,  0.4854,  0.9392,  0.0111],
          [-0.9666, -0.5393,  0.3010,  0.6850,  0.7560,  0.5727],
          [-0.0436, -0.9012,  0.4945,  0.9079,  0.6551,  0.3986]],
 
         [[-0.3056,  1.0390,  0.0308,  0.6546,  0.0471,  0.8311],
          [-0.7194,  0.3709,  0.8517,  0.2788,  0.5287,  0.8799],
          [-0.4567, -0.1643,  0.7193,  0.8177,  0.7088,  0.1128],
          [-0.3710, -0.7775,  0.7493,  0.7202,  0.5490,  0.0412]]],
        device='cuda:0'),
 torch.Size([2, 4, 6]))

In [22]:
test.shape, result.shape

(torch.Size([2, 4, 6]), torch.Size([2, 4, 6]))

In [7]:
# 회전 행렬 생성 최적화
def create_rotation_matrix_v2(d, thetas):
    """
    Create a batch of rotation matrices from the given thetas.
    
    Args:
        d (int): The dimensionality of the rotation matrix (must be even).
        thetas (Tensor): A tensor of shape (batch_size, seq_len, d/2) containing the rotation angles.
    
    Returns:
        Tensor: A tensor of shape (batch_size, seq_len, d, d) containing the rotation matrices.
    
    """
    batch_size, seq_len, _ = thetas.size()

    cos_t = torch.cos(thetas).unsqueeze(-1)
    sin_t = torch.sin(thetas).unsqueeze(-1)
    
    # 각각의 회전 행렬 요소 계산
    cos_stack = cos_t.repeat(1, 1, d//2, 2).reshape(batch_size, seq_len, d)
    sin_stack = sin_t.repeat(1, 1, d//2, 2).reshape(batch_size, seq_len, d)
    
    # 회전 행렬 생성
    R = torch.eye(d, device=thetas.device).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_len, 1, 1)
    R[:, :, ::2, ::2] = cos_stack
    R[:, :, 1::2, 1::2] = cos_stack
    R[:, :, ::2, 1::2] = -sin_stack
    R[:, :, 1::2, ::2] = sin_stack
    
    return R

create_rotation_matrix_v2(DIM_MODEL, var_tensor)
query = query.view(BATCH_SIZE * MAX_SEQ, DIM_MODEL)  # 배치 및 시퀀스 차원을 결합하여 행렬 곱셈을 위한 준비
R = R.view(BATCH_SIZE * MAX_SEQ, DIM_MODEL, DIM_MODEL)  # 배치 및 시퀀스 차원을 결합하여 행렬 곱셈을 위한 준비

result = torch.bmm(R, query.unsqueeze(-1)).squeeze(-1).view(BATCH_SIZE, MAX_SEQ, DIM_MODEL)  
result

RuntimeError: shape '[2, 4, 6]' is invalid for input of size 288

In [50]:
query = torch.rand(BATCH_SIZE, MAX_SEQ, DIM_MODEL, device=device)
query

tensor([[[0.5181, 0.0157, 0.2816, 0.8102, 0.7531, 0.3929],
         [0.2677, 0.4532, 0.2546, 0.8741, 0.9514, 0.7601],
         [0.8453, 0.2444, 0.2729, 0.9591, 0.9617, 0.2893],
         [0.7103, 0.3587, 0.7432, 0.9069, 0.2104, 0.2306]],

        [[0.7296, 0.3562, 0.6634, 0.2245, 0.1866, 0.4317],
         [0.5114, 0.0046, 0.2526, 0.8282, 0.1672, 0.7417],
         [0.1639, 0.5335, 0.7061, 0.8498, 0.0443, 0.5817],
         [0.5541, 0.2003, 0.7650, 0.3248, 0.2091, 0.3983]]], device='mps:0')

In [52]:
""" expand to hole word embedding shape """

for b in range(BATCH_SIZE):
    for s in range(MAX_SEQ):
        sub_rotary, sub_word = R[b, s, :, :], query[b, s, :]
        query[b, s, :] = torch.matmul(sub_rotary, sub_word)
        

In [53]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1) 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.theta.shape)
        print(self.theta)

tensor([[[ 0.2667,  0.4445,  0.2437,  0.8224,  0.7523,  0.3945],
         [-0.5235,  0.0548,  0.1725,  0.8940,  0.9481,  0.7642],
         [-0.8714, -0.1227,  0.1372,  0.9877,  0.9599,  0.2956],
         [-0.1929, -0.7720,  0.5630,  1.0285,  0.2084,  0.2324]],

        [[ 0.0945,  0.8064,  0.6523,  0.2550,  0.1857,  0.4321],
         [-0.2170,  0.4631,  0.1747,  0.8481,  0.1640,  0.7424],
         [-0.2375, -0.5051,  0.5813,  0.9395,  0.0405,  0.5820],
         [-0.2106, -0.5502,  0.6919,  0.4605,  0.2057,  0.4001]]],
       device='mps:0')

In [8]:
rotary_pos_enc = RoPE(DIM_MODEL)
rotary_pos_enc()

torch.Size([384])
tensor([1.0000e+00, 9.7630e-01, 9.5316e-01, 9.3057e-01, 9.0852e-01, 8.8699e-01,
        8.6596e-01, 8.4544e-01, 8.2540e-01, 8.0584e-01, 7.8674e-01, 7.6810e-01,
        7.4989e-01, 7.3212e-01, 7.1477e-01, 6.9783e-01, 6.8129e-01, 6.6515e-01,
        6.4938e-01, 6.3399e-01, 6.1897e-01, 6.0430e-01, 5.8997e-01, 5.7599e-01,
        5.6234e-01, 5.4901e-01, 5.3600e-01, 5.2330e-01, 5.1090e-01, 4.9879e-01,
        4.8697e-01, 4.7543e-01, 4.6416e-01, 4.5316e-01, 4.4242e-01, 4.3193e-01,
        4.2170e-01, 4.1170e-01, 4.0195e-01, 3.9242e-01, 3.8312e-01, 3.7404e-01,
        3.6517e-01, 3.5652e-01, 3.4807e-01, 3.3982e-01, 3.3177e-01, 3.2390e-01,
        3.1623e-01, 3.0873e-01, 3.0142e-01, 2.9427e-01, 2.8730e-01, 2.8049e-01,
        2.7384e-01, 2.6735e-01, 2.6102e-01, 2.5483e-01, 2.4879e-01, 2.4289e-01,
        2.3714e-01, 2.3152e-01, 2.2603e-01, 2.2067e-01, 2.1544e-01, 2.1034e-01,
        2.0535e-01, 2.0049e-01, 1.9573e-01, 1.9110e-01, 1.8657e-01, 1.8214e-01,
        1.7783e-01, 1.

In [8]:
""" test for parallel multi-head attention
1) permute. contiguous, reshape
2) directly call reshape from torch.matmul(attention_matrix, v)
"""

query = torch.rand(1, 8, 4, 2, device=device)
BS, SEQ_LEN, NUM_HEADS, DIM_HEADS = query.shape
print(f"original query: {query.shape, query}")

test1 = query.permute(0, 2, 1, 3).contiguous().reshape(-1, SEQ_LEN, NUM_HEADS*DIM_HEADS)
print(f"test1: {test1.shape, test1}")

test2 = query.reshape(-1, SEQ_LEN, NUM_HEADS*DIM_HEADS)
print(f"test2: {test2.shape, test2}")

original query: (torch.Size([1, 8, 4, 2]), tensor([[[[0.9168, 0.0523],
          [0.5834, 0.3648],
          [0.1445, 0.6178],
          [0.5795, 0.7004]],

         [[0.9812, 0.6302],
          [0.1581, 0.9039],
          [0.3351, 0.6687],
          [0.9043, 0.7900]],

         [[0.4702, 0.4620],
          [0.0865, 0.3290],
          [0.5967, 0.5650],
          [0.1954, 0.9916]],

         [[0.4202, 0.9388],
          [0.4909, 0.4016],
          [0.4052, 0.9085],
          [0.5089, 0.3139]],

         [[0.8943, 0.1367],
          [0.1901, 0.1161],
          [0.3471, 0.0353],
          [0.1799, 0.8890]],

         [[0.8990, 0.4810],
          [0.8155, 0.6152],
          [0.7975, 0.7668],
          [0.6558, 0.6093]],

         [[0.6960, 0.4698],
          [0.1492, 0.0688],
          [0.0338, 0.2450],
          [0.6533, 0.4803]],

         [[0.2576, 0.7272],
          [0.7371, 0.8030],
          [0.2122, 0.0413],
          [0.5253, 0.0612]]]], device='cuda:0'))
test1: (torch.Size([1, 8, 

In [13]:
""" source code from original linear transformers paper github
"""

max_seq = 5
dim_model = 8
num_heads = 4
dim_head = dim_model // num_heads
batch_size = 3

# 3, 5, 4, 2
query = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)
key = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)
value = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)

kv = torch.einsum("nshd,nshm->nhmd", key, value)
print(f"kv: {kv.shape, kv}")

z = 1/(torch.einsum("nlhd,nhd->nlh", query, key.sum(dim=1))+1e-6)
print(f"z: {z.shape, z}")

qkv = torch.einsum("nlhd,nhmd->nlhm", query, kv)
print(f"qkv: {qkv.shape, qkv}")

V = torch.einsum("nlhd,nhmd,nlh->nlhm", query, kv, z)
print(f"V: {V.shape, V}")

test_v = torch.einsum("nlhm,nlh->nlhm", qkv, z)
print(f"test_v: {test_v.shape, test_v}")

kv: (torch.Size([3, 4, 2, 2]), tensor([[[[1.8712, 1.7441],
          [2.0152, 1.8818]],

         [[1.2480, 1.6586],
          [0.9948, 1.8648]],

         [[1.8897, 1.1054],
          [1.4765, 1.0960]],

         [[1.1829, 2.3455],
          [1.1215, 1.7871]]],


        [[[1.2551, 1.4782],
          [0.6835, 0.8580]],

         [[2.0954, 2.2804],
          [1.2951, 1.9219]],

         [[1.5423, 1.0609],
          [2.4590, 1.7339]],

         [[2.1012, 1.5119],
          [1.1772, 0.9965]]],


        [[[1.3803, 2.0247],
          [1.0918, 1.3187]],

         [[1.6456, 1.3337],
          [0.9049, 1.3710]],

         [[1.0539, 1.3264],
          [1.2296, 1.8057]],

         [[1.3767, 1.8384],
          [0.7315, 0.7558]]]], device='mps:0'))
z: (torch.Size([3, 5, 4]), tensor([[[0.4404, 0.2922, 0.5909, 0.2315],
         [0.2317, 0.4030, 0.4618, 0.2901],
         [0.6622, 0.5419, 0.2251, 0.2437],
         [0.6661, 0.3618, 0.3576, 0.9171],
         [0.2153, 0.3808, 0.4363, 0.4122]],

       

In [114]:
""" alias of each tensor dimension for torch einsum
b: batch_size
s: sequence length
q: dim_head of query
k: dim_head of key
v: dim_head of value
"""
def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)
padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
query, key = kernel_fn(query), kernel_fn(key)
query[padding_mask == 1], key[padding_mask == 1], value[padding_mask == 1] = 0, 0, 0

KV = torch.matmul(value.permute(0, 2, 1), key.permute(0, 1, 2))
print(f"KV: {KV.shape, KV}")

Z = 1 / torch.clamp(torch.mul(query,key.sum(dim=1).unsqueeze(1)).sum(dim=-1), min=1e-6)
print(f"Z: {Z.shape, Z}")

V = torch.einsum("bsq,bvk,bs->bsv", query, KV, Z)
print(f"V: {V.shape, V}")

KV: (torch.Size([3, 2, 2]), tensor([[[1.6945, 2.5599],
         [1.6986, 2.4920]],

        [[3.7326, 3.1925],
         [3.5027, 3.1093]],

        [[1.4581, 2.1359],
         [2.1605, 2.8166]]], device='cuda:0'))
Z: (torch.Size([3, 5]), tensor([[5.9278e-02, 8.4242e-02, 8.2484e-02, 1.0000e+06, 1.0000e+06],
        [4.6290e-02, 4.3917e-02, 6.5639e-02, 6.5267e-02, 1.0000e+06],
        [1.0978e-01, 1.0496e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]],
       device='cuda:0'))
V: (torch.Size([3, 5, 2]), tensor([[[0.9854, 0.9707],
         [1.0248, 1.0095],
         [1.0131, 0.9979],
         [0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[1.1236, 1.0728],
         [1.1181, 1.0675],
         [1.1167, 1.0662],
         [1.1213, 1.0706],
         [0.0000, 0.0000]],

        [[1.0962, 1.5180],
         [1.1320, 1.5677],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]], device='cuda:0'))


In [15]:
# [bs, num_heads, dim_head, max_seq]*[bs, num_heads, max_seq, dim_head] = [bs, num_heads, dim_head, dim_head]
KV = torch.matmul(value.permute(0, 2, 3, 1), key.permute(0, 2, 1, 3))  
print(f"KV: {KV.shape, KV}")

print(f"query: {query.shape}")
print(f"key sum: {key.sum(dim=1).shape}")
 
Z = 1 / torch.clamp(torch.mul(query,key.sum(dim=1).unsqueeze(1)).sum(dim=-1), min=1e-6)
print(f"Z: {Z, Z.shape}")

KV: (torch.Size([3, 4, 2, 2]), tensor([[[[1.8712, 1.7441],
          [2.0152, 1.8818]],

         [[1.2480, 1.6586],
          [0.9948, 1.8648]],

         [[1.8897, 1.1054],
          [1.4765, 1.0960]],

         [[1.1829, 2.3455],
          [1.1215, 1.7871]]],


        [[[1.2551, 1.4782],
          [0.6835, 0.8580]],

         [[2.0954, 2.2804],
          [1.2951, 1.9219]],

         [[1.5423, 1.0609],
          [2.4590, 1.7339]],

         [[2.1012, 1.5119],
          [1.1772, 0.9965]]],


        [[[1.3803, 2.0247],
          [1.0918, 1.3187]],

         [[1.6456, 1.3337],
          [0.9049, 1.3710]],

         [[1.0539, 1.3264],
          [1.2296, 1.8057]],

         [[1.3767, 1.8384],
          [0.7315, 0.7558]]]], device='mps:0'))
query: torch.Size([3, 5, 4, 2])
key sum: torch.Size([3, 4, 2])
Z: (tensor([[[0.4404, 0.2922, 0.5909, 0.2315],
         [0.2317, 0.4030, 0.4618, 0.2901],
         [0.6622, 0.5419, 0.2251, 0.2437],
         [0.6661, 0.3618, 0.3576, 0.9171],
         [0.

In [5]:
# test_query = query.reshape(batch_size, max_seq, dim_head)
# test_key = key.reshape(batch_size, max_seq, dim_head)
# test_value = value.reshape(batch_size, max_seq, dim_head)
# 
# KV = torch.matmul(test_key.permute(0, 2, 1), test_value)
# print(f"KV: {KV.shape, KV}")
# QKV = torch.matmul(test_query, KV)
# print(f"QKV: {QKV.shape, QKV}")

# summation_key = test_key.sum(dim=1).unsqueeze(1)
# print(f"key, summation_key: {test_key.shape, summation_key.shape}")
# 
# Z = 1/torch.clamp(torch.mul(test_query, summation_key), min=1e-6)
# print(f"Normalizer Z: {Z, Z.shape}")
# 
# linear_attn_matrix = torch.matmul(QKV, Z)
# print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

RuntimeError: shape '[3, 5, 2]' is invalid for input of size 120

In [84]:
quadratic = 512*64*64*512 + 512*512*512*64
linear = 64*512*512*64 + 512*64*64*64

quadratic, linear, quadratic // linear

(9663676416, 1207959552, 8)

In [98]:
""" Inputs for Linear Attentions
"""

max_seq = 5
dim_model = 2
num_heads = 12
batch_size = 3
# dim_head = dim_model // num_heads
dim_head = dim_model

def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
query.shape, key.shape, value.shape, padding_mask.shape

(torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5]))

In [99]:
""" Testing for Linear Attentions: KV Matrix
"""
k_query, k_key = kernel_fn(query), kernel_fn(key)
k_query[padding_mask == 1], k_key[padding_mask == 1], value[padding_mask == 1] = 0, 0, 0 

KV = torch.matmul(k_key.permute(0, 2, 1), value)
KV, KV.shape

(tensor([[[1.6423, 2.3725],
          [1.9186, 2.7227]],
 
         [[2.5099, 4.3219],
          [2.0842, 3.6178]],
 
         [[1.7249, 2.7095],
          [1.6283, 2.5576]]], device='cuda:0'),
 torch.Size([3, 2, 2]))

In [100]:
""" Testing for Linear Attentions: QKV
"""

QKV = torch.matmul(k_query, KV)
QKV, QKV.shape

(tensor([[[ 6.9833,  9.9921],
          [ 4.8467,  6.9438],
          [ 5.7299,  8.2035],
          [ 0.2364,  0.8221],
          [ 0.5723,  0.9037]],
 
         [[ 6.3969, 11.0627],
          [ 6.1451, 10.6128],
          [ 6.6618, 11.5190],
          [ 7.9548, 13.7416],
          [ 0.5258,  0.1820]],
 
         [[ 4.7064,  7.3925],
          [ 4.8010,  7.5412],
          [ 0.4323,  0.3144],
          [ 0.6062,  0.6203],
          [ 0.8389,  0.5675]]], device='cuda:0'),
 torch.Size([3, 5, 2]))

In [101]:
""" Testing for Linear Attentions: QKV / normalizer Z
softmax는 row-wise하게 정규화 하는데, 우리도 똑같이 정규화가 필요하지 않냐고 그래서 Z가 필요
여기서, QKV가 결국 row-wise 하게 정규화 되어야 한다는게 포인트임
그렇다면 Z의 크기는 16, 512, 64가 되어야 한다
"""

summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1).permute(0,2,1)

print(f"key, summation_key: {key.shape, summation_key.shape}")

Z = 1/torch.clamp(torch.matmul(k_query, summation_key), min=1e-6)
print(f"Normalizer Z: {Z, Z.shape}")

key, summation_key: (torch.Size([3, 5, 2]), torch.Size([3, 2, 2]))
Normalizer Z: (tensor([[[5.6876e-02, 5.6876e-02],
         [8.1770e-02, 8.1770e-02],
         [6.9246e-02, 6.9246e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]],

        [[6.3809e-02, 6.3809e-02],
         [6.7458e-02, 6.7458e-02],
         [6.1383e-02, 6.1383e-02],
         [5.1966e-02, 5.1966e-02],
         [1.0000e+06, 1.0000e+06]],

        [[9.9465e-02, 9.9465e-02],
         [9.7511e-02, 9.7511e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]]], device='cuda:0'), torch.Size([3, 5, 2]))


In [102]:
""" Testing for Linear Attentions: QKV / normalizer Z 
"""

linear_attn_matrix = torch.mul(QKV, Z)
print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

linear_attn_matrix: (torch.Size([3, 5, 2]), tensor([[[3.9718e-01, 5.6831e-01],
         [3.9631e-01, 5.6779e-01],
         [3.9677e-01, 5.6806e-01],
         [2.3637e+05, 8.2210e+05],
         [5.7225e+05, 9.0375e+05]],

        [[4.0818e-01, 7.0590e-01],
         [4.1454e-01, 7.1592e-01],
         [4.0892e-01, 7.0707e-01],
         [4.1338e-01, 7.1410e-01],
         [5.2581e+05, 1.8200e+05]],

        [[4.6812e-01, 7.3529e-01],
         [4.6815e-01, 7.3535e-01],
         [4.3231e+05, 3.1439e+05],
         [6.0615e+05, 6.2035e+05],
         [8.3891e+05, 5.6749e+05]]], device='cuda:0'))


In [103]:
""" Comparing with pure self-attention
"""

attn_matrix = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(dim_head))
print(f"attn_matrix: {attn_matrix.shape, attn_matrix}")

padding_mask = padding_mask.unsqueeze(1)
attn_matrix = attn_matrix.masked_fill(padding_mask == 1, float('-inf'))
attention_dist = F.softmax(attn_matrix, dim=-1)
print(f"attention_dist: {attention_dist.shape, attention_dist}")

attention_matrix = torch.matmul(attention_dist, value)
print(f"attention_matrix: {attention_matrix.shape, attention_matrix}")

attn_matrix: (torch.Size([3, 5, 5]), tensor([[[0.3713, 0.6363, 1.0094, 0.5870, 0.7410],
         [0.1055, 0.2372, 0.3701, 0.2420, 0.2239],
         [0.2164, 0.4022, 0.6346, 0.3840, 0.4393],
         [0.3069, 0.4024, 0.6519, 0.3202, 0.5834],
         [0.1555, 0.2623, 0.4167, 0.2403, 0.3093]],

        [[0.2487, 0.1702, 0.1442, 0.2789, 0.3663],
         [0.0394, 0.2411, 0.1327, 0.3839, 0.3884],
         [0.2519, 0.2133, 0.1670, 0.3475, 0.4341],
         [0.2112, 0.4541, 0.2813, 0.7280, 0.7887],
         [0.0527, 0.1607, 0.0945, 0.2568, 0.2699]],

        [[0.3726, 0.5146, 0.2369, 0.3336, 0.5193],
         [0.4214, 0.5751, 0.1496, 0.1591, 0.4101],
         [0.7383, 1.0163, 0.4114, 0.5539, 0.9420],
         [0.6791, 0.9325, 0.3396, 0.4380, 0.8084],
         [0.5018, 0.6880, 0.2337, 0.2919, 0.5716]]], device='cuda:0'))
attention_dist: (torch.Size([3, 5, 5]), tensor([[[0.2383, 0.3106, 0.4511, 0.0000, 0.0000],
         [0.2904, 0.3313, 0.3783, 0.0000, 0.0000],
         [0.2686, 0.3234, 0.4080

In [25]:
""" Comparing with pure self-attention by KL-divergence
배치 별, 평균 총합 0.5 ~ 2 정도 차이남, 이게 보니까 처음에 랜덤 초기화 빨로 갈리네
"""

kl_div = F.kl_div(linear_attn_matrix.log(), attention_matrix, reduction='batchmean')
kl_div

tensor(inf, device='cuda:0')

In [33]:
""" Test code for applying padding masking to linear attention """
test_q = torch.randn(3, 5, 4, device=device)
test_k = torch.randn(3, 5, 4, device=device)
test_v = torch.randn(3, 5, 4, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
test_k[padding_mask == 1] = 0
test_k

tensor([[[-0.0264, -0.0698, -1.3971, -1.5667],
         [ 0.5706, -1.4718,  0.3398, -0.5280],
         [ 1.5465,  0.2453,  1.0911,  0.6062],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.4773,  1.9033,  0.6571,  0.4388],
         [ 0.1047, -0.3162, -3.4738, -0.4424],
         [ 0.4799,  0.7670, -1.0056, -0.4248],
         [ 1.2279, -0.7639,  1.4043, -0.2604],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1115,  0.3349, -1.0625,  0.5592],
         [ 0.3361,  1.7460,  1.9226,  0.5757],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [92]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1)  # 세타값을 살리려면 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.i_arr.shape)
        print(self.i_arr)
        print(self.theta.shape)
        print(self.theta)

In [93]:
test = RoPE()
test()

torch.Size([384])
tensor([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
        127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
        141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 174, 175, 176