In [3]:
import sys
import os
# Add the project root directory to the Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)


## Rope 2D

In [4]:
import torch
from src.utils.rope2d import get_rope_2d_angles, apply_2d_rope

In [5]:
batch_size, seq_len, embed_dim = 1, 4, 768

torch.manual_seed(42)
Q = torch.randn((batch_size, seq_len, embed_dim))
Q[0][1] = Q[0][0]
Q[0][3] = Q[0][0]
# Q[1][3] = Q[1][0]
print(Q.shape)
print("Q:\n", Q)
cos_theta, sin_theta = get_rope_2d_angles(embed_dim, int(seq_len**0.5))

Q_r = apply_2d_rope(Q, cos_theta, sin_theta)
print("Q_r:\n", Q_r)

torch.Size([1, 4, 768])
Q:
 tensor([[[ 1.9269,  1.4873,  0.9007,  ..., -1.6034, -0.4298,  0.5762],
         [ 1.9269,  1.4873,  0.9007,  ..., -1.6034, -0.4298,  0.5762],
         [-0.8497, -0.6987, -0.2052,  ..., -0.0298,  1.2715,  1.0849],
         [ 1.9269,  1.4873,  0.9007,  ..., -1.6034, -0.4298,  0.5762]]])
Q_r:
 tensor([[[ 1.9269,  1.4873,  0.9007,  ..., -1.6034, -0.4298,  0.5762],
         [-0.2104,  2.4250,  2.2381,  ..., -1.6034, -0.4298,  0.5762],
         [-0.8497, -0.6987, -0.2052,  ..., -0.0298,  1.2714,  1.0850],
         [-0.2104,  2.4250,  2.2381,  ..., -1.6034, -0.4299,  0.5761]]])


In [6]:
# Compute L2 norm before and after rotation
norm_before = torch.norm(Q, dim=-1)     # [bsz, seq_len]
norm_after  = torch.norm(Q_r, dim=-1)   # [bsz, seq_len]
# Check equality
torch.allclose(norm_before, norm_after)

True

In [7]:
# Check attention score (self attention) before & after rotation
batch_size, seq_len, embed_dim= Q.shape

# attention score
score_before = torch.matmul(Q, Q.transpose(-1, -2)) / (embed_dim**0.5)
score_before = torch.softmax(score_before, dim=-1)

score_after  = torch.matmul(Q_r, Q_r.transpose(-1, -2)) / (embed_dim**0.5)
score_after  = torch.softmax(score_after, dim=-1)


diff = (score_before - score_after).abs()
# round to 2 digits
torch.set_printoptions(precision=2, sci_mode=False)
print(diff)

tensor([[[    0.15,     0.04,     0.00,     0.11],
         [    0.08,     0.09,     0.00,     0.01],
         [    0.00,     0.00,     0.00,     0.00],
         [    0.13,     0.01,     0.00,     0.12]]])


## Absolute Position Encoding

In [8]:
from src.utils.abs_pos_embed import get_2d_sincos_pos_embed

abs_pos_embed = get_2d_sincos_pos_embed(embed_dim, int(seq_len**0.5))
abs_pos_embed = torch.from_numpy(abs_pos_embed)
Q_abs_r = Q+abs_pos_embed.unsqueeze(0)


In [9]:
# Compute L2 norm before and after rotation
norm_before = torch.norm(Q, dim=-1)     # [bsz, seq_len]
norm_after  = torch.norm(Q_abs_r, dim=-1)   # [bsz, seq_len]
# Check equality
torch.allclose(norm_before, norm_after, atol=1e0)

False

In [10]:
# Check attention score (self attention) before & after rotation
batch_size, seq_len, embed_dim= Q.shape

score_after_abs  = torch.matmul(Q_abs_r, Q_abs_r.transpose(-1, -2)) / (embed_dim**0.5)
score_after_abs  = torch.softmax(score_after, dim=-1)

diff_abs = (score_before - score_after_abs).abs()
# round to 2 digits
torch.set_printoptions(precision=2, sci_mode=False)
diff_abs, diff_abs.mean()

(tensor([[[0.02, 0.08, 0.19, 0.09],
          [0.08, 0.04, 0.19, 0.07],
          [0.17, 0.17, 0.52, 0.17],
          [0.10, 0.06, 0.19, 0.03]]]),
 tensor(0.14))

In [11]:
# self attention with Q shape [1, 4, 768]
# Q[0][1] = Q[0][0]
# Q[0][3] = Q[0][0]
score_before # attention score before apply position embeding

tensor([[[    0.33,     0.33,     0.00,     0.33],
         [    0.33,     0.33,     0.00,     0.33],
         [    0.00,     0.00,     1.00,     0.00],
         [    0.33,     0.33,     0.00,     0.33]]])

In [12]:
print((score_before - score_after).abs().mean())
score_after # rope 2d

tensor(0.05)


tensor([[[    0.48,     0.29,     0.00,     0.22],
         [    0.26,     0.42,     0.00,     0.32],
         [    0.00,     0.00,     1.00,     0.00],
         [    0.21,     0.34,     0.00,     0.45]]])

In [13]:
print((score_before - score_after_abs).abs().mean())
score_after_abs # abs position encoding 2d

tensor(0.14)


tensor([[[0.31, 0.26, 0.19, 0.24],
         [0.25, 0.29, 0.19, 0.26],
         [0.17, 0.17, 0.48, 0.17],
         [0.24, 0.27, 0.19, 0.30]]])