In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import experiment.metrics.metric as metric
from experiment.metrics.metric import cosine_similarity

In [2]:
""" set default device to mps """

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='cuda')

In [84]:
quadratic = 512*64*64*512 + 512*512*512*64
linear = 64*512*512*64 + 512*64*64*64

quadratic, linear, quadratic // linear

(9663676416, 1207959552, 8)

In [84]:
""" Inputs for Linear Attentions
"""

max_seq = 5
dim_model = 2
num_heads = 12
batch_size = 3
# dim_head = dim_model // num_heads
dim_head = dim_model

def kernel_fn(x: Tensor):
    return F.elu(x) + 1

query = torch.rand(batch_size, max_seq, dim_head, device=device)
key = torch.rand(batch_size, max_seq, dim_head, device=device)
value = torch.rand(batch_size, max_seq, dim_head, device=device)
padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
query.shape, key.shape, value.shape, padding_mask.shape

(torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5, 2]),
 torch.Size([3, 5]))

In [85]:
""" Testing for Linear Attentions: KV Matrix
"""
k_query, k_key = kernel_fn(query), kernel_fn(key)
k_query[padding_mask == 1], k_key[padding_mask == 1], value[padding_mask == 1] = 0, 0, 0 
KV = torch.matmul(k_key.permute(0, 2, 1), value)
KV, KV.shape

(tensor([[[2.6345, 1.1583],
          [2.8550, 1.1325]],
 
         [[3.4392, 2.3450],
          [3.0084, 2.1036]],
 
         [[2.5609, 1.5268],
          [2.3060, 1.3550]]], device='cuda:0'),
 torch.Size([3, 2, 2]))

In [86]:
""" Testing for Linear Attentions: QKV
"""

QKV = torch.matmul(k_query, KV)
QKV, QKV.shape

(tensor([[[ 7.4525,  3.1025],
          [ 8.6008,  3.6039],
          [ 9.1468,  3.7965],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]],
 
         [[ 9.4990,  6.5660],
          [11.1671,  7.7140],
          [ 8.8458,  6.1189],
          [ 7.9799,  5.5069],
          [ 0.0000,  0.0000]],
 
         [[ 5.3303,  3.1553],
          [ 7.4697,  4.4315],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]]], device='cuda:0'),
 torch.Size([3, 5, 2]))

In [87]:
""" Testing for Linear Attentions: QKV / normalizer Z
softmax는 row-wise하게 정규화 하는데, 우리도 똑같이 정규화가 필요하지 않냐고 그래서 Z가 필요
여기서, QKV가 결국 row-wise 하게 정규화 되어야 한다는게 포인트임
그렇다면 Z의 크기는 16, 512, 64가 되어야 한다
"""

summation_key = k_key.sum(dim=1).unsqueeze(1).expand(-1, dim_head, -1).permute(0,2,1)

print(f"key, summation_key: {key.shape, summation_key.shape}")

Z = 1/torch.clamp(torch.matmul(k_query, summation_key), min=1e-6)
print(f"Normalizer Z: {Z, Z.shape}")

key, summation_key: (torch.Size([3, 5, 2]), torch.Size([3, 2, 2]))
Normalizer Z: (tensor([[[8.6900e-02, 8.6900e-02],
         [7.5102e-02, 7.5102e-02],
         [7.0888e-02, 7.0888e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]],

        [[5.3556e-02, 5.3556e-02],
         [4.5500e-02, 4.5500e-02],
         [5.7585e-02, 5.7585e-02],
         [6.3557e-02, 6.3557e-02],
         [1.0000e+06, 1.0000e+06]],

        [[1.3892e-01, 1.3892e-01],
         [9.8719e-02, 9.8719e-02],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06]]], device='cuda:0'), torch.Size([3, 5, 2]))


In [88]:
""" Testing for Linear Attentions: QKV / normalizer Z 
"""

linear_attn_matrix = torch.mul(QKV, Z)
print(f"linear_attn_matrix: {linear_attn_matrix.shape, linear_attn_matrix}")

linear_attn_matrix: (torch.Size([3, 5, 2]), tensor([[[0.6476, 0.2696],
         [0.6459, 0.2707],
         [0.6484, 0.2691],
         [0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.5087, 0.3516],
         [0.5081, 0.3510],
         [0.5094, 0.3524],
         [0.5072, 0.3500],
         [0.0000, 0.0000]],

        [[0.7405, 0.4383],
         [0.7374, 0.4375],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]], device='cuda:0'))


In [89]:
""" Comparing with pure self-attention
"""

attn_matrix = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(dim_head))
print(f"attn_matrix: {attn_matrix.shape, attn_matrix}")

padding_mask = padding_mask.unsqueeze(1)
attn_matrix = attn_matrix.masked_fill(padding_mask == 1, float('-inf'))
attention_dist = F.softmax(attn_matrix, dim=-1)
print(f"attention_dist: {attention_dist.shape, attention_dist}")

attention_matrix = torch.matmul(attention_dist, value)
print(f"attention_matrix: {attention_matrix.shape, attention_matrix}")

attn_matrix: (torch.Size([3, 5, 5]), tensor([[[0.4147, 0.0776, 0.1389, 0.2053, 0.2956],
         [0.6370, 0.0975, 0.2554, 0.4823, 0.4708],
         [0.7776, 0.1510, 0.2500, 0.3431, 0.5501],
         [0.4173, 0.0418, 0.2101, 0.4855, 0.3255],
         [0.5681, 0.1091, 0.1849, 0.2598, 0.4028]],

        [[0.6696, 0.2685, 0.2128, 0.3013, 0.1852],
         [1.0094, 0.3946, 0.3941, 0.5482, 0.3113],
         [0.5430, 0.2241, 0.1260, 0.1845, 0.1298],
         [0.3224, 0.1239, 0.1412, 0.1947, 0.1061],
         [0.6482, 0.2254, 0.4568, 0.6134, 0.2892]],

        [[0.1210, 0.0444, 0.1299, 0.0543, 0.0903],
         [0.6136, 0.4479, 0.6321, 0.6190, 0.3482],
         [0.4278, 0.1398, 0.4614, 0.1657, 0.3278],
         [0.3293, 0.2082, 0.3431, 0.2826, 0.2027],
         [0.3856, 0.1955, 0.4076, 0.2564, 0.2612]]], device='cuda:0'))
attention_dist: (torch.Size([3, 5, 5]), tensor([[[0.4044, 0.2887, 0.3069, 0.0000, 0.0000],
         [0.4413, 0.2573, 0.3013, 0.0000, 0.0000],
         [0.4707, 0.2516, 0.2777

In [55]:
""" Comparing with pure self-attention by KL-divergence
배치 별, 평균 총합 0.5 ~ 2 정도 차이남, 이게 보니까 처음에 랜덤 초기화 빨로 갈리네
"""

kl_div = F.kl_div(linear_attn_matrix.log(), attention_matrix, reduction='batchmean')
kl_div

tensor(nan, device='cuda:0')

In [33]:
""" Test code for applying padding masking to linear attention """
test_q = torch.randn(3, 5, 4, device=device)
test_k = torch.randn(3, 5, 4, device=device)
test_v = torch.randn(3, 5, 4, device=device)

padding_mask = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 1, 1, 1]],
    device=device
)
test_k[padding_mask == 1] = 0
test_k

tensor([[[-0.0264, -0.0698, -1.3971, -1.5667],
         [ 0.5706, -1.4718,  0.3398, -0.5280],
         [ 1.5465,  0.2453,  1.0911,  0.6062],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.4773,  1.9033,  0.6571,  0.4388],
         [ 0.1047, -0.3162, -3.4738, -0.4424],
         [ 0.4799,  0.7670, -1.0056, -0.4248],
         [ 1.2279, -0.7639,  1.4043, -0.2604],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1115,  0.3349, -1.0625,  0.5592],
         [ 0.3361,  1.7460,  1.9226,  0.5757],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [92]:
class RoPE(nn.Module):
    def __init__(self, dim_model: int= 768):
        super().__init__()
        self.dim_model = dim_model
        self.i_arr = torch.arange(1, int(dim_model/2)+1)  # 세타값을 살리려면 
        self.theta = 10000**(-2*(self.i_arr - 1)/self.dim_model)
    
    def forward(self):
        print(self.i_arr.shape)
        print(self.i_arr)
        print(self.theta.shape)
        print(self.theta)

In [93]:
test = RoPE()
test()

torch.Size([384])
tensor([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
        113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
        127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
        141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154,
        155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 174, 175, 176