In [2]:
import math
import torch
from torch import nn

def get_slopes(n_heads: int):
    """Generate the per-head slopes used by the attention bias.

    Let ``n`` be the largest power of two that does not exceed ``n_heads``.
    The first ``n`` slopes are ``m_0 ** k`` for ``k = 1..n`` with
    ``m_0 = 2 ** (-8 / n)``. If ``n_heads`` is not a power of two, the
    remaining ``n_heads - n`` slopes are ``m_hat_0 ** k`` for the odd
    exponents ``k = 1, 3, 5, ...`` with ``m_hat_0 = 2 ** (-4 / n)``.

    Args:
        n_heads: Number of attention heads.

    Returns:
        A 1-D tensor with ``n_heads`` slopes.
    """
    # Largest power of two <= n_heads.
    base_heads = 2 ** math.floor(math.log2(n_heads))
    exponents = torch.arange(1, 1 + base_heads)
    slopes = torch.pow(2.0 ** (-8.0 / base_heads), exponents)

    extra_heads = n_heads - base_heads
    if extra_heads <= 0:
        # n_heads is already a power of two; nothing to append.
        return slopes

    # Fill the leftover heads with every other slope of the finer
    # (base 2 ** (-4 / n)) geometric sequence.
    odd_exponents = torch.arange(1, 1 + 2 * extra_heads, 2)
    extra_slopes = torch.pow(2.0 ** (-4.0 / base_heads), odd_exponents)
    return torch.cat([slopes, extra_slopes])

@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    """Build ALiBi attention biases from the per-head slopes.

    The slopes from ``get_slopes`` are multiplied with a lower-triangular
    relative-distance matrix whose entry ``[i, j]`` is ``j - i`` for
    ``j <= i`` and ``0`` above the diagonal.

    Args:
        n_heads: Number of attention heads.
        mask: Only its first-dimension size (taken as the sequence length)
            and its device are used; its values are not read.

    Returns:
        A tensor of shape ``[seq_len, seq_len, n_heads]`` on ``mask.device``.
    """
    slopes = get_slopes(n_heads).to(mask.device)
    seq_len = mask.size(0)

    # Build the positions on the mask's device so the product below does not
    # mix devices.
    positions = torch.arange(seq_len, device=mask.device)

    # distance[i, j] = j - i (0 on the diagonal, -1 one step to the left, ...).
    # The offset has to vary along the column axis: a value that is constant
    # within a row would cancel out if the bias is added to attention scores
    # before a softmax over that axis.
    distance = torch.tril(positions[None, :] - positions[:, None])

    return distance[:, :, None] * slopes[None, None, :]



In [3]:
# Example sizes shared by the cells below.
seq_len = 10
n_heads = 8

# Per-head slopes; with 8 heads this is the geometric sequence 1/2, 1/4, ..., 1/256.
m = get_slopes(n_heads)
print(m)



tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])


In [4]:
# Hand-built relative-distance matrix: entry [i, k] is -(i - k) below the
# diagonal and 0 on and above it.
alibi_biases = torch.zeros(seq_len, seq_len)
for row in range(1, seq_len):
    # Row `row` holds -row, ..., -1 in the columns left of the diagonal.
    alibi_biases[row, :row] = torch.arange(-row, 0)
print(alibi_biases)



tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-2., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-3., -2., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-4., -3., -2., -1.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-5., -4., -3., -2., -1.,  0.,  0.,  0.,  0.,  0.],
        [-6., -5., -4., -3., -2., -1.,  0.,  0.,  0.,  0.],
        [-7., -6., -5., -4., -3., -2., -1.,  0.,  0.,  0.],
        [-8., -7., -6., -5., -4., -3., -2., -1.,  0.,  0.],
        [-9., -8., -7., -6., -5., -4., -3., -2., -1.,  0.]])


In [5]:
# Shapes of the two broadcast operands: a trailing singleton axis on the
# distance matrix and two leading singleton axes on the slope vector.
print(alibi_biases[:, :, None].shape, m[None, None, :].shape)


torch.Size([10, 10, 1]) torch.Size([1, 1, 8])


In [6]:
# Broadcast product: each distance is scaled by every slope, giving a
# (seq_len, seq_len, n_heads) tensor.
alibi_biases[:, :, None] * m[None, None, :]

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
