In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import experiment.metrics.metric as metric
from experiment.metrics.metric import cosine_similarity

In [2]:
""" set default device to mps """

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='mps')

In [84]:
quadratic = 512*64*64*512 + 512*512*512*64
linear = 64*512*512*64 + 512*64*64*64

quadratic, linear, quadratic // linear

(9663676416, 1207959552, 8)

In [44]:
""" Inputs for Linear Attentions
"""

max_seq = 512
dim_model = 768
num_heads = 12
batch_size = 16
dim_head = dim_model // num_heads

def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)

query.shape, key.shape, value.shape

(torch.Size([16, 512, 64]),
 torch.Size([16, 512, 64]),
 torch.Size([16, 512, 64]))

In [45]:
""" Testing for Linear Attentions: KV Matrix
"""
k_query, k_key = kernel_fn(query), kernel_fn(key)
KV = torch.matmul(k_key.permute(0, 2, 1), value)
KV, KV.shape

(tensor([[[387.7043, 371.8534, 393.9587,  ..., 366.6582, 389.5907, 388.7321],
          [384.7640, 366.5524, 393.6349,  ..., 363.6440, 385.3273, 391.2770],
          [382.3161, 372.1201, 392.7029,  ..., 365.1796, 387.1248, 390.4623],
          ...,
          [387.3051, 374.5810, 395.6080,  ..., 364.7379, 387.8368, 391.8069],
          [383.2746, 371.5667, 395.2921,  ..., 364.0172, 388.0489, 391.9443],
          [390.4175, 375.6101, 397.2274,  ..., 368.6044, 385.3950, 394.7017]],
 
         [[382.1408, 385.7607, 388.2805,  ..., 392.3750, 388.4670, 380.8877],
          [382.6382, 385.0953, 383.5913,  ..., 387.5719, 377.2379, 377.3565],
          [384.5132, 390.9588, 386.0755,  ..., 384.6969, 380.5969, 381.2653],
          ...,
          [382.7216, 385.5114, 383.3484,  ..., 386.4922, 382.5864, 382.0888],
          [378.1702, 378.6817, 382.9681,  ..., 380.4366, 378.6287, 373.3374],
          [383.8071, 385.2278, 390.3502,  ..., 387.0571, 387.7598, 385.2114]],
 
         [[378.5798, 382.270

In [46]:
""" Testing for Linear Attentions: QKV
"""

QKV = torch.matmul(k_query, KV)
QKV, QKV.shape

(tensor([[[37280.8555, 36106.5078, 38063.3203,  ..., 35232.7578,
           37621.9023, 37787.2266],
          [35636.3633, 34503.3555, 36376.5078,  ..., 33671.8125,
           35951.5469, 36108.4570],
          [37547.0391, 36366.0469, 38342.8047,  ..., 35478.3047,
           37879.2656, 38039.2852],
          ...,
          [36239.8320, 35104.4883, 37016.8203,  ..., 34247.1875,
           36575.5156, 36744.7695],
          [35854.8789, 34725.8242, 36609.7891,  ..., 33877.2266,
           36178.7227, 36333.9648],
          [36674.7148, 35525.2969, 37465.9922,  ..., 34655.2773,
           37013.2930, 37181.4023]],
 
         [[37053.4492, 37189.0391, 37369.2539,  ..., 37365.4492,
           37059.3750, 36788.0547],
          [36076.2773, 36206.7773, 36375.6758,  ..., 36369.5977,
           36080.3594, 35822.3477],
          [37330.4180, 37472.3672, 37639.2539,  ..., 37643.1602,
           37354.6016, 37074.9961],
          ...,
          [35750.8477, 35888.6289, 36049.4062,  ..., 36049

In [52]:
""" Testing for Linear Attentions: QKV / normalizer Z
softmax는 row-wise하게 정규화 하는데, 우리도 똑같이 정규화가 필요하지 않냐고 그래서 Z가 필요
여기서, QKV가 결국 row-wise 하게 정규화 되어야 한다는게 포인트임
그렇다면 Z의 크기는 16, 512, 64가 되어야 한다
"""

# summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1)
summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, 64, -1)

print(f"key, summation_key: {key.shape, summation_key.shape}")

Z = 1/torch.matmul(k_query, summation_key)
print(f"Normalizer Z: {Z, Z.shape}")

key, summation_key: (torch.Size([16, 512, 64]), torch.Size([16, 64, 64]))
Normalizer Z: (tensor([[[1.3309e-05, 1.3452e-05, 1.3403e-05,  ..., 1.3358e-05,
          1.3345e-05, 1.3336e-05],
         [1.3928e-05, 1.4077e-05, 1.4026e-05,  ..., 1.3980e-05,
          1.3966e-05, 1.3956e-05],
         [1.3218e-05, 1.3360e-05, 1.3312e-05,  ..., 1.3267e-05,
          1.3254e-05, 1.3245e-05],
         ...,
         [1.3690e-05, 1.3837e-05, 1.3787e-05,  ..., 1.3741e-05,
          1.3728e-05, 1.3718e-05],
         [1.3842e-05, 1.3990e-05, 1.3940e-05,  ..., 1.3893e-05,
          1.3879e-05, 1.3870e-05],
         [1.3528e-05, 1.3674e-05, 1.3624e-05,  ..., 1.3579e-05,
          1.3565e-05, 1.3556e-05]],

        [[1.3507e-05, 1.3617e-05, 1.3581e-05,  ..., 1.3526e-05,
          1.3750e-05, 1.3504e-05],
         [1.3873e-05, 1.3986e-05, 1.3950e-05,  ..., 1.3893e-05,
          1.4123e-05, 1.3871e-05],
         [1.3408e-05, 1.3518e-05, 1.3482e-05,  ..., 1.3427e-05,
          1.3650e-05, 1.3406e-05],
    

In [53]:
""" Testing for Linear Attentions: QKV / normalizer Z 
"""

linear_attn_matrix = QKV * Z
print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

linear_attn_matrix: (torch.Size([16, 512, 64]), tensor([[[0.4962, 0.4857, 0.5102,  ..., 0.4707, 0.5021, 0.5039],
         [0.4963, 0.4857, 0.5102,  ..., 0.4707, 0.5021, 0.5039],
         [0.4963, 0.4859, 0.5104,  ..., 0.4707, 0.5021, 0.5038],
         ...,
         [0.4961, 0.4858, 0.5104,  ..., 0.4706, 0.5021, 0.5041],
         [0.4963, 0.4858, 0.5103,  ..., 0.4707, 0.5021, 0.5039],
         [0.4962, 0.4858, 0.5104,  ..., 0.4706, 0.5021, 0.5040]],

        [[0.5005, 0.5064, 0.5075,  ..., 0.5054, 0.5096, 0.4968],
         [0.5005, 0.5064, 0.5074,  ..., 0.5053, 0.5096, 0.4969],
         [0.5005, 0.5065, 0.5075,  ..., 0.5055, 0.5099, 0.4970],
         ...,
         [0.5004, 0.5064, 0.5073,  ..., 0.5053, 0.5095, 0.4967],
         [0.5006, 0.5067, 0.5076,  ..., 0.5055, 0.5098, 0.4971],
         [0.5005, 0.5065, 0.5075,  ..., 0.5054, 0.5097, 0.4970]],

        [[0.4995, 0.5003, 0.4826,  ..., 0.5131, 0.5121, 0.5006],
         [0.4996, 0.5004, 0.4828,  ..., 0.5132, 0.5123, 0.5007],
         [

In [54]:
""" Comparing with pure self-attention
"""

attn_matrix = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(dim_head))
print(f"attn_matrix: {attn_matrix.shape, attn_matrix}")

attention_dist = F.softmax(attn_matrix, dim=-1)
print(f"attention_dist: {attention_dist.shape, attention_dist}")

attention_matrix = torch.matmul(attention_dist, value)
print(f"attention_matrix: {attention_matrix.shape, attention_matrix}")

attn_matrix: (torch.Size([16, 512, 512]), tensor([[[2.1725, 1.7386, 2.0302,  ..., 1.8410, 2.2081, 2.2750],
         [1.6531, 1.3445, 1.7683,  ..., 1.8096, 1.8506, 1.9382],
         [1.9809, 1.7924, 2.1559,  ..., 2.0277, 2.0989, 2.2697],
         ...,
         [1.9047, 1.7059, 1.8541,  ..., 1.8300, 2.0202, 2.1081],
         [1.9011, 1.5412, 1.6208,  ..., 1.7042, 1.9113, 1.9657],
         [2.0518, 1.5056, 1.9919,  ..., 1.9880, 2.0154, 2.1064]],

        [[2.1710, 2.0610, 2.3059,  ..., 1.8151, 1.9755, 1.9143],
         [1.9178, 1.8013, 2.1319,  ..., 1.7848, 1.9060, 1.6572],
         [2.2061, 2.0478, 2.4350,  ..., 2.0977, 2.0709, 1.8986],
         ...,
         [1.7587, 1.8317, 2.0257,  ..., 1.6909, 1.7247, 1.5575],
         [2.0705, 2.2105, 2.2533,  ..., 1.8328, 2.0666, 1.8621],
         [2.2224, 2.3359, 2.3493,  ..., 2.1668, 2.2152, 2.0820]],

        [[1.8835, 2.0041, 2.1307,  ..., 1.8779, 1.8816, 1.9613],
         [1.8289, 1.8542, 1.9272,  ..., 1.7709, 1.8638, 1.7751],
         [2.1128

In [18]:
""" Comparing with pure self-attention by KL-divergence
배치 별, 평균 총합 0.5 ~ 2 정도 차이남, 이게 보니까 처음에 랜덤 초기화 빨로 갈리네
"""

kl_div = F.kl_div(linear_attn_matrix.log(), attention_matrix, reduction='batchmean')
kl_div

tensor(-1.4282, device='mps:0')

In [56]:
A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
B = torch.tensor([
    [7, 8],
    [9, 10],
    [11, 12]
])
result = torch.matmul(A, B)
print(result)  # 출력: tensor([[ 58,  64], [139, 154]])


tensor([[ 58,  64],
        [139, 154]])


In [ ]:
""" Test code for applying padding masking to linear attention """
test_q = torch.randn(3, 5, 4, device=device)
test_k = torch.randn(3, 5, 4, device=device)
test_v = torch.randn(3, 5, 4, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)



In [92]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1)  # 세타값을 살리려면 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.i_arr.shape)
        print(self.i_arr)
        print(self.theta.shape)
        print(self.theta)

In [93]:
test = RoPE()
test()

torch.Size([384])
tensor([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
        127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
        141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 174, 175, 176