In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import experiment.metrics.metric as metric
from experiment.metrics.metric import cosine_similarity

In [105]:
""" set default device to mps """

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='cuda')

In [137]:
""" source code from original linear transformers paper github
"""

max_seq = 5
dim_model = 8
num_heads = 4
dim_head = dim_model // num_heads
batch_size = 3

# 3, 5, 4, 2
query = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)
key = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)
value = torch.rand(batch_size, max_seq, num_heads, dim_head, device=device)

kv = torch.einsum("nshd,nshm->nhmd", key, value)
print(f"kv: {kv.shape, kv}")

z = 1/(torch.einsum("nlhd,nhd->nlh", query, key.sum(dim=1))+1e-6)
print(f"z: {z.shape, z}")

V = torch.einsum("nlhd,nhmd,nlh->nlhm", query, kv, z)
print(f"V: {V.shape, V}")

result = V.reshape(batch_size, max_seq, dim_model)
print(f"result: {result.shape, result}")

kv: (torch.Size([3, 4, 2, 2]), tensor([[[[0.7384, 0.6831],
          [0.9784, 1.0978]],

         [[0.4672, 0.1484],
          [1.2292, 0.3618]],

         [[0.6822, 1.3503],
          [1.0276, 1.1962]],

         [[1.3133, 1.3501],
          [1.3980, 1.5897]]],


        [[[1.9564, 2.4758],
          [1.2201, 1.3477]],

         [[0.9246, 2.2588],
          [0.6081, 1.6430]],

         [[1.4280, 1.1428],
          [1.3536, 1.0312]],

         [[0.2574, 1.5583],
          [0.4280, 1.7079]]],


        [[[1.2459, 0.8152],
          [1.7307, 1.1222]],

         [[1.3764, 1.3455],
          [1.9562, 2.0202]],

         [[1.6096, 2.4803],
          [1.4274, 2.0940]],

         [[1.7893, 1.3124],
          [1.9845, 1.4693]]]], device='cuda:0'))
z: (torch.Size([3, 5, 4]), tensor([[[0.3384, 0.9140, 0.7687, 0.4793],
         [0.3341, 0.2972, 2.2779, 0.4126],
         [1.0245, 0.4344, 0.3763, 0.3041],
         [0.3448, 2.7105, 0.4792, 0.4867],
         [0.4329, 0.5398, 0.2911, 0.2871]],

      

In [138]:
test_query = query.reshape(batch_size, max_seq, dim_model)
test_key = key.reshape(batch_size, max_seq, dim_model)
test_value = value.reshape(batch_size, max_seq, dim_model)

KV = torch.matmul(test_key.permute(0, 2, 1), test_value)
print(f"KV: {KV.shape, KV}")
QKV = torch.matmul(test_query, KV)
print(f"QKV: {QKV.shape, QKV}")

# summation_key = test_key.sum(dim=1).unsqueeze(1)
# print(f"key, summation_key: {test_key.shape, summation_key.shape}")
# 
# Z = 1/torch.clamp(torch.mul(test_query, summation_key), min=1e-6)
# print(f"Normalizer Z: {Z, Z.shape}")
# 
# linear_attn_matrix = torch.matmul(QKV, Z)
# print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

KV: (torch.Size([3, 8, 8]), tensor([[[0.7384, 0.9784, 0.3342, 0.7195, 1.3459, 1.3874, 1.4920, 1.4375],
         [0.6831, 1.0978, 0.5157, 0.5412, 1.2465, 1.4181, 1.8009, 1.5604],
         [1.1993, 0.4940, 0.4672, 1.2292, 1.0608, 1.3758, 1.7437, 2.1373],
         [0.3745, 0.0969, 0.1484, 0.3618, 0.3303, 0.5065, 0.6203, 0.6606],
         [0.7328, 0.4144, 0.4063, 0.4152, 0.6822, 1.0276, 1.4041, 1.3024],
         [1.0060, 1.1435, 0.4359, 0.9744, 1.3503, 1.1962, 1.4616, 1.8681],
         [0.8513, 0.5454, 0.3302, 0.6241, 0.9275, 1.1026, 1.3133, 1.3980],
         [0.6354, 1.1195, 0.4148, 0.9301, 1.1938, 1.0773, 1.3501, 1.5897]],

        [[1.9564, 1.2201, 2.0111, 1.7797, 2.0668, 1.9457, 1.7694, 1.5701],
         [2.4758, 1.3477, 2.0415, 2.1555, 2.3791, 2.2440, 1.8498, 1.8669],
         [0.6941, 0.6325, 0.9246, 0.6081, 0.8107, 0.6381, 0.7634, 0.7005],
         [1.8719, 1.3881, 2.2588, 1.6430, 1.9772, 2.0031, 2.3346, 1.9521],
         [1.4197, 1.1326, 1.6652, 1.1260, 1.4280, 1.3536, 1.8467, 1.67

In [84]:
quadratic = 512*64*64*512 + 512*512*512*64
linear = 64*512*512*64 + 512*64*64*64

quadratic, linear, quadratic // linear

(9663676416, 1207959552, 8)

In [98]:
""" Inputs for Linear Attentions
"""

max_seq = 5
dim_model = 2
num_heads = 12
batch_size = 3
# dim_head = dim_model // num_heads
dim_head = dim_model

def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
query.shape, key.shape, value.shape, padding_mask.shape

(torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5]))

In [99]:
""" Testing for Linear Attentions: KV Matrix
"""
k_query, k_key = kernel_fn(query), kernel_fn(key)
k_query[padding_mask == 1], k_key[padding_mask == 1], value[padding_mask == 1] = 0, 0, 0 

KV = torch.matmul(k_key.permute(0, 2, 1), value)
KV, KV.shape

(tensor([[[1.6423, 2.3725],
          [1.9186, 2.7227]],
 
         [[2.5099, 4.3219],
          [2.0842, 3.6178]],
 
         [[1.7249, 2.7095],
          [1.6283, 2.5576]]], device='cuda:0'),
 torch.Size([3, 2, 2]))

In [100]:
""" Testing for Linear Attentions: QKV
"""

QKV = torch.matmul(k_query, KV)
QKV, QKV.shape

(tensor([[[ 6.9833,  9.9921],
          [ 4.8467,  6.9438],
          [ 5.7299,  8.2035],
          [ 0.2364,  0.8221],
          [ 0.5723,  0.9037]],
 
         [[ 6.3969, 11.0627],
          [ 6.1451, 10.6128],
          [ 6.6618, 11.5190],
          [ 7.9548, 13.7416],
          [ 0.5258,  0.1820]],
 
         [[ 4.7064,  7.3925],
          [ 4.8010,  7.5412],
          [ 0.4323,  0.3144],
          [ 0.6062,  0.6203],
          [ 0.8389,  0.5675]]], device='cuda:0'),
 torch.Size([3, 5, 2]))

In [101]:
""" Testing for Linear Attentions: QKV / normalizer Z
softmax는 row-wise하게 정규화 하는데, 우리도 똑같이 정규화가 필요하지 않냐고 그래서 Z가 필요
여기서, QKV가 결국 row-wise 하게 정규화 되어야 한다는게 포인트임
그렇다면 Z의 크기는 16, 512, 64가 되어야 한다
"""

summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1).permute(0,2,1)

print(f"key, summation_key: {key.shape, summation_key.shape}")

Z = 1/torch.clamp(torch.matmul(k_query, summation_key), min=1e-6)
print(f"Normalizer Z: {Z, Z.shape}")

key, summation_key: (torch.Size([3, 5, 2]), torch.Size([3, 2, 2]))
Normalizer Z: (tensor([[[5.6876e-02, 5.6876e-02],
         [8.1770e-02, 8.1770e-02],
         [6.9246e-02, 6.9246e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]],

        [[6.3809e-02, 6.3809e-02],
         [6.7458e-02, 6.7458e-02],
         [6.1383e-02, 6.1383e-02],
         [5.1966e-02, 5.1966e-02],
         [1.0000e+06, 1.0000e+06]],

        [[9.9465e-02, 9.9465e-02],
         [9.7511e-02, 9.7511e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]]], device='cuda:0'), torch.Size([3, 5, 2]))


In [102]:
""" Testing for Linear Attentions: QKV / normalizer Z 
"""

linear_attn_matrix = torch.mul(QKV, Z)
print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

linear_attn_matrix: (torch.Size([3, 5, 2]), tensor([[[3.9718e-01, 5.6831e-01],
         [3.9631e-01, 5.6779e-01],
         [3.9677e-01, 5.6806e-01],
         [2.3637e+05, 8.2210e+05],
         [5.7225e+05, 9.0375e+05]],

        [[4.0818e-01, 7.0590e-01],
         [4.1454e-01, 7.1592e-01],
         [4.0892e-01, 7.0707e-01],
         [4.1338e-01, 7.1410e-01],
         [5.2581e+05, 1.8200e+05]],

        [[4.6812e-01, 7.3529e-01],
         [4.6815e-01, 7.3535e-01],
         [4.3231e+05, 3.1439e+05],
         [6.0615e+05, 6.2035e+05],
         [8.3891e+05, 5.6749e+05]]], device='cuda:0'))


In [103]:
""" Comparing with pure self-attention
"""

attn_matrix = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(dim_head))
print(f"attn_matrix: {attn_matrix.shape, attn_matrix}")

padding_mask = padding_mask.unsqueeze(1)
attn_matrix = attn_matrix.masked_fill(padding_mask == 1, float('-inf'))
attention_dist = F.softmax(attn_matrix, dim=-1)
print(f"attention_dist: {attention_dist.shape, attention_dist}")

attention_matrix = torch.matmul(attention_dist, value)
print(f"attention_matrix: {attention_matrix.shape, attention_matrix}")

attn_matrix: (torch.Size([3, 5, 5]), tensor([[[0.3713, 0.6363, 1.0094, 0.5870, 0.7410],
         [0.1055, 0.2372, 0.3701, 0.2420, 0.2239],
         [0.2164, 0.4022, 0.6346, 0.3840, 0.4393],
         [0.3069, 0.4024, 0.6519, 0.3202, 0.5834],
         [0.1555, 0.2623, 0.4167, 0.2403, 0.3093]],

        [[0.2487, 0.1702, 0.1442, 0.2789, 0.3663],
         [0.0394, 0.2411, 0.1327, 0.3839, 0.3884],
         [0.2519, 0.2133, 0.1670, 0.3475, 0.4341],
         [0.2112, 0.4541, 0.2813, 0.7280, 0.7887],
         [0.0527, 0.1607, 0.0945, 0.2568, 0.2699]],

        [[0.3726, 0.5146, 0.2369, 0.3336, 0.5193],
         [0.4214, 0.5751, 0.1496, 0.1591, 0.4101],
         [0.7383, 1.0163, 0.4114, 0.5539, 0.9420],
         [0.6791, 0.9325, 0.3396, 0.4380, 0.8084],
         [0.5018, 0.6880, 0.2337, 0.2919, 0.5716]]], device='cuda:0'))
attention_dist: (torch.Size([3, 5, 5]), tensor([[[0.2383, 0.3106, 0.4511, 0.0000, 0.0000],
         [0.2904, 0.3313, 0.3783, 0.0000, 0.0000],
         [0.2686, 0.3234, 0.4080

In [25]:
""" Comparing with pure self-attention by KL-divergence
배치 별, 평균 총합 0.5 ~ 2 정도 차이남, 이게 보니까 처음에 랜덤 초기화 빨로 갈리네
"""

kl_div = F.kl_div(linear_attn_matrix.log(), attention_matrix, reduction='batchmean')
kl_div

tensor(inf, device='cuda:0')

In [33]:
""" Test code for applying padding masking to linear attention """
test_q = torch.randn(3, 5, 4, device=device)
test_k = torch.randn(3, 5, 4, device=device)
test_v = torch.randn(3, 5, 4, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
test_k[padding_mask == 1] = 0
test_k

tensor([[[-0.0264, -0.0698, -1.3971, -1.5667],
         [ 0.5706, -1.4718,  0.3398, -0.5280],
         [ 1.5465,  0.2453,  1.0911,  0.6062],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.4773,  1.9033,  0.6571,  0.4388],
         [ 0.1047, -0.3162, -3.4738, -0.4424],
         [ 0.4799,  0.7670, -1.0056, -0.4248],
         [ 1.2279, -0.7639,  1.4043, -0.2604],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1115,  0.3349, -1.0625,  0.5592],
         [ 0.3361,  1.7460,  1.9226,  0.5757],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [92]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1)  # 세타값을 살리려면 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.i_arr.shape)
        print(self.i_arr)
        print(self.theta.shape)
        print(self.theta)

In [93]:
test = RoPE()
test()

torch.Size([384])
tensor([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
        127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
        141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 174, 175, 176