In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import experiment.metrics.metric as metric
from experiment.metrics.metric import cosine_similarity

In [2]:
""" set default device to mps """

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='cuda')

In [84]:
quadratic = 512*64*64*512 + 512*512*512*64
linear = 64*512*512*64 + 512*64*64*64

quadratic, linear, quadratic // linear

(9663676416, 1207959552, 8)

In [16]:
""" Inputs for Linear Attentions
"""

max_seq = 512
dim_model = 768
num_heads = 12
batch_size = 16
dim_head = dim_model // num_heads

def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)

query.shape, key.shape, value.shape

(torch.Size([16, 512, 64]),
 torch.Size([16, 512, 64]),
 torch.Size([16, 512, 64]))

In [17]:
""" Testing for Linear Attentions: KV Matrix
"""
k_query, k_key = kernel_fn(query), kernel_fn(key)
KV = torch.matmul(k_key.permute(0, 2, 1), value)
KV, KV.shape

(tensor([[[379.2272, 363.0186, 405.7908,  ..., 374.7338, 377.7190, 375.9309],
          [382.8666, 358.6458, 400.0844,  ..., 372.0980, 374.7551, 371.8542],
          [382.7687, 360.3486, 408.0707,  ..., 377.3307, 377.7459, 373.8784],
          ...,
          [380.0161, 362.1071, 402.2729,  ..., 374.3331, 373.1011, 374.5012],
          [385.2000, 365.0188, 408.9546,  ..., 378.1088, 378.6559, 381.6468],
          [386.3977, 363.9138, 406.5035,  ..., 377.4268, 378.1671, 373.0061]],
 
         [[373.5042, 380.5840, 375.6325,  ..., 418.3877, 394.5324, 387.5174],
          [370.0285, 382.9928, 377.6545,  ..., 418.3036, 395.9686, 387.1384],
          [366.8623, 378.5196, 371.6406,  ..., 414.5091, 388.5123, 383.9243],
          ...,
          [364.2798, 377.8297, 369.8265,  ..., 409.8225, 388.3452, 379.5322],
          [365.6089, 376.0946, 366.6428,  ..., 405.7575, 384.7750, 378.0374],
          [372.2574, 383.5679, 375.1379,  ..., 416.8979, 396.5555, 389.8101]],
 
         [[386.6698, 386.072

In [18]:
""" Testing for Linear Attentions: QKV
"""

QKV = torch.matmul(k_query, KV)
QKV, QKV.shape

(tensor([[[36912.2305, 34959.8398, 39018.5078,  ..., 36197.7031,
           36283.9922, 36281.0859],
          [37543.7148, 35541.6406, 39672.7461,  ..., 36797.5664,
           36884.6523, 36891.7734],
          [37298.1992, 35312.1914, 39421.0391,  ..., 36573.7734,
           36654.2070, 36651.3984],
          ...,
          [37635.8438, 35628.3555, 39772.7500,  ..., 36887.6562,
           36976.1797, 36984.2109],
          [37077.3633, 35104.5781, 39195.0625,  ..., 36351.3633,
           36436.9688, 36436.4336],
          [36029.4844, 34113.1367, 38085.0508,  ..., 35314.2383,
           35399.2734, 35410.6836]],
 
         [[34804.4492, 35761.9805, 35283.5273,  ..., 39085.5000,
           36992.9805, 36303.0273],
          [35909.3945, 36904.7227, 36419.6562,  ..., 40332.0898,
           38173.2188, 37446.2070],
          [34745.9336, 35694.1250, 35222.6328,  ..., 39017.3516,
           36924.3242, 36230.8633],
          ...,
          [33928.2227, 34872.9297, 34411.9258,  ..., 38108

In [19]:
""" Testing for Linear Attentions: QKV / normalizer Z
softmax는 row-wise하게 정규화 하는데, 우리도 똑같이 정규화가 필요하지 않냐고 그래서 Z가 필요
여기서, QKV가 결국 row-wise 하게 정규화 되어야 한다는게 포인트임
그렇다면 Z의 크기는 16, 512, 64가 되어야 한다
"""

# summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1)
summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1).permute(0,2,1)

print(f"key, summation_key: {key.shape, summation_key.shape}")

Z = 1/torch.matmul(k_query, summation_key)
print(f"Normalizer Z: {Z, Z.shape}")

key, summation_key: (torch.Size([16, 512, 64]), torch.Size([16, 64, 64]))
Normalizer Z: (tensor([[[1.3546e-05, 1.3546e-05, 1.3546e-05,  ..., 1.3546e-05,
          1.3546e-05, 1.3546e-05],
         [1.3323e-05, 1.3323e-05, 1.3323e-05,  ..., 1.3323e-05,
          1.3323e-05, 1.3323e-05],
         [1.3408e-05, 1.3408e-05, 1.3408e-05,  ..., 1.3408e-05,
          1.3408e-05, 1.3408e-05],
         ...,
         [1.3290e-05, 1.3290e-05, 1.3290e-05,  ..., 1.3290e-05,
          1.3290e-05, 1.3290e-05],
         [1.3488e-05, 1.3488e-05, 1.3488e-05,  ..., 1.3488e-05,
          1.3488e-05, 1.3488e-05],
         [1.3881e-05, 1.3881e-05, 1.3881e-05,  ..., 1.3881e-05,
          1.3881e-05, 1.3881e-05]],

        [[1.3802e-05, 1.3802e-05, 1.3802e-05,  ..., 1.3802e-05,
          1.3802e-05, 1.3802e-05],
         [1.3376e-05, 1.3376e-05, 1.3376e-05,  ..., 1.3376e-05,
          1.3376e-05, 1.3376e-05],
         [1.3828e-05, 1.3828e-05, 1.3828e-05,  ..., 1.3828e-05,
          1.3828e-05, 1.3828e-05],
    

In [20]:
""" Testing for Linear Attentions: QKV / normalizer Z 
"""

linear_attn_matrix = torch.mul(QKV, Z)
print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

linear_attn_matrix: (torch.Size([16, 512, 64]), tensor([[[0.5000, 0.4736, 0.5286,  ..., 0.4903, 0.4915, 0.4915],
         [0.5002, 0.4735, 0.5285,  ..., 0.4902, 0.4914, 0.4915],
         [0.5001, 0.4735, 0.5285,  ..., 0.4904, 0.4914, 0.4914],
         ...,
         [0.5002, 0.4735, 0.5286,  ..., 0.4902, 0.4914, 0.4915],
         [0.5001, 0.4735, 0.5287,  ..., 0.4903, 0.4915, 0.4915],
         [0.5001, 0.4735, 0.5287,  ..., 0.4902, 0.4914, 0.4915]],

        [[0.4804, 0.4936, 0.4870,  ..., 0.5394, 0.5106, 0.5010],
         [0.4803, 0.4936, 0.4871,  ..., 0.5395, 0.5106, 0.5009],
         [0.4805, 0.4936, 0.4870,  ..., 0.5395, 0.5106, 0.5010],
         ...,
         [0.4803, 0.4937, 0.4871,  ..., 0.5395, 0.5106, 0.5011],
         [0.4804, 0.4937, 0.4871,  ..., 0.5394, 0.5106, 0.5009],
         [0.4803, 0.4936, 0.4872,  ..., 0.5395, 0.5106, 0.5010]],

        [[0.4971, 0.5008, 0.5173,  ..., 0.4852, 0.4978, 0.5051],
         [0.4970, 0.5008, 0.5172,  ..., 0.4852, 0.4978, 0.5051],
         [

In [21]:
""" Comparing with pure self-attention
"""

attn_matrix = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(dim_head))
print(f"attn_matrix: {attn_matrix.shape, attn_matrix}")

attention_dist = F.softmax(attn_matrix, dim=-1)
print(f"attention_dist: {attention_dist.shape, attention_dist}")

attention_matrix = torch.matmul(attention_dist, value)
print(f"attention_matrix: {attention_matrix.shape, attention_matrix}")

attn_matrix: (torch.Size([16, 512, 512]), tensor([[[2.2749, 1.9622, 1.9245,  ..., 2.1339, 1.7951, 1.9170],
         [2.2627, 2.0734, 2.0455,  ..., 2.2357, 2.0193, 2.0438],
         [2.3995, 1.9777, 2.0033,  ..., 2.3385, 2.0694, 2.0637],
         ...,
         [2.3627, 2.0644, 2.1282,  ..., 2.2813, 1.9809, 2.0667],
         [2.4133, 2.0074, 1.9502,  ..., 2.2852, 1.9838, 2.0096],
         [2.1986, 1.6938, 1.7289,  ..., 2.0817, 1.8390, 1.8193]],

        [[1.8141, 1.9757, 2.1443,  ..., 1.9633, 1.9246, 2.0287],
         [2.1800, 1.8266, 2.1786,  ..., 2.0976, 1.9825, 2.0626],
         [1.7646, 1.8651, 1.9468,  ..., 1.9577, 1.9091, 1.8173],
         ...,
         [1.5570, 1.6722, 1.8400,  ..., 1.7102, 1.5921, 1.8502],
         [1.9877, 1.9390, 1.9734,  ..., 2.0285, 2.1212, 1.9973],
         [1.7255, 1.7074, 1.8447,  ..., 1.8424, 1.7684, 1.6644]],

        [[1.9951, 1.8562, 2.0399,  ..., 2.3830, 2.0018, 2.1632],
         [1.6052, 1.8875, 1.9085,  ..., 1.9347, 1.8365, 1.8557],
         [1.7836

In [22]:
""" Comparing with pure self-attention by KL-divergence
배치 별, 평균 총합 0.5 ~ 2 정도 차이남, 이게 보니까 처음에 랜덤 초기화 빨로 갈리네
"""

kl_div = F.kl_div(linear_attn_matrix.log(), attention_matrix, reduction='batchmean')
kl_div

tensor(1.6304, device='cuda:0')

In [33]:
""" Test code for applying padding masking to linear attention """
test_q = torch.randn(3, 5, 4, device=device)
test_k = torch.randn(3, 5, 4, device=device)
test_v = torch.randn(3, 5, 4, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
test_k[padding_mask == 1] = 0
test_k

tensor([[[-0.0264, -0.0698, -1.3971, -1.5667],
         [ 0.5706, -1.4718,  0.3398, -0.5280],
         [ 1.5465,  0.2453,  1.0911,  0.6062],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.4773,  1.9033,  0.6571,  0.4388],
         [ 0.1047, -0.3162, -3.4738, -0.4424],
         [ 0.4799,  0.7670, -1.0056, -0.4248],
         [ 1.2279, -0.7639,  1.4043, -0.2604],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1115,  0.3349, -1.0625,  0.5592],
         [ 0.3361,  1.7460,  1.9226,  0.5757],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [92]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1)  # 세타값을 살리려면 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.i_arr.shape)
        print(self.i_arr)
        print(self.theta.shape)
        print(self.theta)

In [93]:
test = RoPE()
test()

torch.Size([384])
tensor([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
        127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
        141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 174, 175, 176