In [1]:
import torch

In [2]:
# Toy RoPE configuration shared by every section of this notebook.
dim = 4  # size of the vector being rotated; one rotary frequency per pair of features (dim/2 in total)
seq_len = 3 # total sequence length
base = 10000  # base of the geometric progression of rotary frequencies
position_id = 1 # position index of current token (re-bound to a LongTensor in a later cell)

In [7]:
# theta_i = base^(-2i/dim) for i = 0 .. dim/2 - 1: one rotation frequency per feature pair.
pair_exponents = torch.arange(0, dim, 2).float() / dim
inv_freq = 1. / (base ** pair_exponents)
inv_freq, inv_freq.shape # dim/2

(tensor([1.0000, 0.0100]), torch.Size([2]))

In [8]:
# Position indices m = 0 .. seq_len-1, created in inv_freq's dtype so the two can be multiplied directly.
t = torch.arange(seq_len, dtype=inv_freq.dtype)
t, t.shape

(tensor([0., 1., 2.]), torch.Size([3]))

In [9]:
# Outer product of positions and inverse frequencies: freqs[m, i] = m * theta_i.
freqs = torch.outer(t, inv_freq)
freqs, freqs.shape # (seq_len, dim/2)

(tensor([[0.0000, 0.0000],
         [1.0000, 0.0100],
         [2.0000, 0.0200]]),
 torch.Size([3, 2]))

## ChatGLM-6B

In [10]:
# Duplicate the angle table along the last axis so column j and column j + dim/2 carry the same angle.
emb = torch.cat((freqs, freqs), dim=-1)
emb, emb.shape  # per row m: m*theta_0, m*theta_1, m*theta_0, m*theta_1

(tensor([[0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0100, 1.0000, 0.0100],
         [2.0000, 0.0200, 2.0000, 0.0200]]),
 torch.Size([3, 4]))

In [11]:
# Cosine / sine tables with a singleton axis inserted at position 1: (seq_len, 1, dim).
cos_cached = emb.cos().unsqueeze(1)
sin_cached = emb.sin().unsqueeze(1)
# Keep only the first seq_len rows (here that is the whole table).
cos = cos_cached[:seq_len]
sin = sin_cached[:seq_len]

In [12]:
# Inspect the cosine table; each row along the first axis is one position.
cos, cos.shape

(tensor([[[ 1.0000,  1.0000,  1.0000,  1.0000]],
 
         [[ 0.5403,  0.9999,  0.5403,  0.9999]],
 
         [[-0.4161,  0.9998, -0.4161,  0.9998]]]),
 torch.Size([3, 1, 4]))

In [13]:
import torch.nn.functional as F

In [50]:
# cos(m * theta) for the token at the given position
position_id = torch.LongTensor([[1]]) # the second position (index 1)
# Row lookup into the (seq_len, dim) table, then add a singleton axis at dim 2.
cos_q = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2) # [sq, b, 1, hn]
cos_q, cos_q.shape  # picks row 1 of the cosine table

(tensor([[[[0.5403, 0.9999, 0.5403, 0.9999]]]]), torch.Size([1, 1, 1, 4]))

In [51]:
# Same lookup against the sine table: sin(m * theta) for the selected position.
sin_q = F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
sin_q, sin_q.shape

(tensor([[[[0.8415, 0.0100, 0.8415, 0.0100]]]]), torch.Size([1, 1, 1, 4]))

In [52]:
def rotate_half(x):
    """Swap the two halves of the last axis and negate the half moved to the front.

    The last axis is split at ``x.shape[-1] // 2`` into ``[a, b]`` and the
    result is ``[-b, a]`` concatenated along that same axis.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    # Explicit positive axis instead of dim=-1, which triggers a bug in earlier torch versions.
    last_axis = front.ndim - 1
    return torch.cat((-back, front), dim=last_axis)


In [60]:
# Toy query vector holding the integers 0 .. dim-1 so every component is easy to trace.
q = torch.arange(dim).view(1, 1, 1, dim) # seq_len, batch_size, heads, dim
q, q.shape

(tensor([[[[0, 1, 2, 3]]]]), torch.Size([1, 1, 1, 4]))

In [61]:
# q_0, q_1, q_2, q_3 -> -q_2, -q_3, q_0, q_1
rh_q = rotate_half(q)
rh_q, rh_q.shape

(tensor([[[[-2, -3,  0,  1]]]]), torch.Size([1, 1, 1, 4]))

In [62]:
# Apply the rotation element-wise: q * cos(m*theta) + rotate_half(q) * sin(m*theta).
# Because of rotate_half, component j is combined with component j + dim/2.
qm = (q * cos_q) + (rh_q * sin_q)
qm, qm.shape

(tensor([[[[-1.6829,  0.9700,  1.0806,  3.0098]]]]), torch.Size([1, 1, 1, 4]))

## Rotation matrix

In [66]:
# Angles m * theta_i for the selected position (position_id is the LongTensor [[1]]).
mtheta = freqs[position_id].squeeze(0) # m*theta_0, m*theta_1
mtheta, mtheta.shape

(tensor([[1.0000, 0.0100]]), torch.Size([1, 2]))

In [69]:
# Cosine and sine of each rotation angle; these fill the 2x2 blocks of the rotation matrix.
cos_mtheta = mtheta.cos()
sin_mtheta = mtheta.sin()
cos_mtheta, cos_mtheta.shape # cos_m_theta_0, cos_m_theta_1

(tensor([[0.5403, 0.9999]]), torch.Size([1, 2]))

In [71]:
# Block-diagonal rotation matrix: one 2x2 rotation per adjacent pair (q_0, q_1), (q_2, q_3).
c0, c1 = cos_mtheta[0]
s0, s1 = sin_mtheta[0]
rmatrix = torch.Tensor([
    [c0, -s0, 0, 0],
    [s0, c0, 0, 0],
    [0, 0, c1, -s1],
    [0, 0, s1, c1],
])
rmatrix, rmatrix.shape

(tensor([[ 0.5403, -0.8415,  0.0000,  0.0000],
         [ 0.8415,  0.5403,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.9999, -0.0100],
         [ 0.0000,  0.0000,  0.0100,  0.9999]]),
 torch.Size([4, 4]))

In [73]:
# q is still 4-D here; it is viewed as a (dim, 1) column vector for the matrix product that follows.
q.shape

torch.Size([1, 1, 1, 4])

In [74]:
# Rotate q as a float column vector: (dim, dim) @ (dim, 1) -> (dim, 1).
q_column = q.view(dim, 1).float()
qm2 = rmatrix @ q_column
qm2, qm2.shape

(tensor([[-0.8415],
         [ 0.5403],
         [ 1.9699],
         [ 3.0198]]),
 torch.Size([4, 1]))

In [76]:
# ChatGLM-6B result, shown for comparison with qm2. The values differ because
# rotate_half pairs component j with j + dim/2, while rmatrix rotates adjacent pairs (2i, 2i+1).
qm

tensor([[[[-1.6829,  0.9700,  1.0806,  3.0098]]]])

## ChatGLM2-6B

In [77]:
# cache[m, i] = [cos(m*theta_i), sin(m*theta_i)]: cos and sin are stacked on a new last axis.
cache = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
cache, cache.shape # per position m: [cos_mtheta_0, sin_mtheta_0], [cos_mtheta_1, sin_mtheta_1]

(tensor([[[ 1.0000,  0.0000],
          [ 1.0000,  0.0000]],
 
         [[ 0.5403,  0.8415],
          [ 0.9999,  0.0100]],
 
         [[-0.4161,  0.9093],
          [ 0.9998,  0.0200]]]),
 torch.Size([3, 2, 2]))

In [79]:
# Index with the LongTensor [[1]]: its two axes are prepended to the selected (dim/2, 2) entry.
rope_cache = cache[position_id] # pick up position 1
rope_cache, rope_cache.shape

(tensor([[[[0.5403, 0.8415],
           [0.9999, 0.0100]]]]),
 torch.Size([1, 1, 2, 2]))

In [80]:
# Each cached (cos, sin) entry rotates one pair of features, so rot_dim is twice the entry count.
num_pairs = rope_cache.shape[-2]
rot_dim = num_pairs * 2
# Group q's last axis into adjacent pairs (q_0, q_1), (q_2, q_3).
qshaped = q.reshape(1, -1, 1, num_pairs, 2)
rot_dim, qshaped, qshaped.shape

(4,
 tensor([[[[[0, 1],
            [2, 3]]]]]),
 torch.Size([1, 1, 1, 2, 2]))

In [None]:
# Reshape rope_cache to the same 5-D layout as qshaped, (1, -1, 1, rot_dim // 2, 2),
# so the cos/sin entries line up with the feature pairs.
rope_cache = rope_cache.view(1, -1, 1, qshaped.size(3), 2)

In [85]:
# Rotate each (even, odd) pair by its angle, i.e. the complex product
# (x0 + i*x1) * (cos + i*sin), with rope_cache[..., 0] = cos and rope_cache[..., 1] = sin.
x_out2 = torch.stack(
    [
        qshaped[..., 0] * rope_cache[..., 0] - qshaped[..., 1] * rope_cache[..., 1],
        qshaped[..., 1] * rope_cache[..., 0] + qshaped[..., 0] * rope_cache[..., 1],
    ],
    -1,
)
# Merge the trailing (pair, 2) axes back into a single feature axis.
qm3 = x_out2.flatten(3)


In [86]:
# ChatGLM2-6B rotated query for position 1.
qm3, qm3.shape

(tensor([[[[-0.8415,  0.5403,  1.9699,  3.0198]]]]), torch.Size([1, 1, 1, 4]))

In [87]:
# Rotation-matrix result (a column vector) for comparison with qm3: the values agree.
qm2

tensor([[-0.8415],
        [ 0.5403],
        [ 1.9699],
        [ 3.0198]])

## LLaMA: complex domain

In [89]:
# Unit-magnitude complex numbers with angle m*theta_i: cos(m*theta_i) + i*sin(m*theta_i).
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # cos_mtheta_0 + sin_mtheta_0*i, cos_mtheta_1 + sin_mtheta_1*i
# NOTE(review): the second element repeats freqs_cis; other cells display `.shape` in this slot — confirm intent.
freqs_cis, freqs_cis

(tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
         [ 0.5403+0.8415j,  0.9999+0.0100j],
         [-0.4161+0.9093j,  0.9998+0.0200j]]),
 tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
         [ 0.5403+0.8415j,  0.9999+0.0100j],
         [-0.4161+0.9093j,  0.9998+0.0200j]]))

In [91]:
# Group q's last axis into adjacent (real, imag) pairs: (1, 1, 1, 4) -> (1, 1, 1, 2, 2) -> (1, 1, 2, 2).
q_ = q.reshape(*q.shape[:-1], -1, 2).reshape(1, 1, 2, 2)
q_, q_.shape

(tensor([[[[0, 1],
           [2, 3]]]]),
 torch.Size([1, 1, 2, 2]))

In [92]:
# Reinterpret each trailing (real, imag) pair as one complex number; q holds integers, so cast to float first.
q_complex_ = torch.view_as_complex(q_.float())
q_complex_, q_complex_.shape

(tensor([[[0.+1.j, 2.+3.j]]]), torch.Size([1, 1, 2]))

In [93]:
# Complex multiplication rotates every pair; broadcasting against freqs_cis (seq_len, dim/2)
# produces one rotated copy of q per position.
rotated_complex = q_complex_ * freqs_cis
# Back to interleaved real values, then merge the trailing (pair, 2) axes into one feature axis.
qm4 = torch.view_as_real(rotated_complex).flatten(2)
qm4, qm4.shape

(tensor([[[ 0.0000,  1.0000,  2.0000,  3.0000],
          [-0.8415,  0.5403,  1.9699,  3.0198],
          [-0.9093, -0.4161,  1.9396,  3.0394]]]),
 torch.Size([1, 3, 4]))

In [94]:
# ChatGLM2-6B result for position 1; it matches row 1 of qm4.
qm3

tensor([[[[-0.8415,  0.5403,  1.9699,  3.0198]]]])