# Transformer Attention

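In “小冬瓜有两把刷子” (literally "Xiao Donggua has two brushes", idiomatically "Xiao Donggua really knows his stuff"), the meaning of “刷子” ("brush") can only be identified by analyzing the surrounding context.

Even then some ambiguity remains: the word denotes neither strictly the physical object "brush" nor strictly "skill". Puns like this are exactly what makes natural language analysis hard.

A language model has to capture this kind of ambiguous meaning.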

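## Sequence Modeling

word2vec assumes that words inside a local window are related, and uses that assumption to build an unsupervised learning task that learns token representations.

Consider the token “刷子”:

- text1: “小冬瓜有两把**刷子**” ("Xiao Donggua really knows his stuff")
- text2: “它用**刷子**画了一只小猫” ("It painted a kitten with a **brush**")

`刷子` carries a different meaning in each context.

If we start from the embedding table, however, the representation $E_{刷子}$ of `刷子` is identical in both sentences.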

We therefore need a context-dependent **transformation** $\mathcal{F}(\cdot\,;\text{context}):\mathbb{R}^d \rightarrow \mathbb{R}^d$ such that:

\begin{align}
\mathcal{F}(E_{刷子};\,\text{text1}) &= X_{\text{text1}:刷子},\\
\mathcal{F}(E_{刷子};\,\text{text2}) &= X_{\text{text2}:刷子},\\
X_{\text{text1}:刷子} &\neq X_{\text{text2}:刷子}
\end{align}

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 10
seq_len = 8
batch_size = 2
dim = 4


text1 = torch.randint(0, vocab_size - 1, (1, seq_len))[0]
text2 = torch.randint(0, vocab_size - 1, (1, seq_len))[0]
print(text1)
print(text2)

tensor([8, 0, 2, 3, 7, 5, 5, 8])
tensor([0, 5, 5, 4, 8, 7, 6, 5])


In [2]:
idx = 0
# Put the same token id (9, standing in for `刷子`) at position idx of both sequences
text1[idx] = 9
text2[idx] = 9
print(text1)
print(text2)

tensor([9, 0, 2, 3, 7, 5, 5, 8])
tensor([9, 5, 5, 4, 8, 7, 6, 5])


In [3]:
E = nn.Embedding(vocab_size, dim)

In [4]:
seq_embd_1 = E(text1)
seq_embd_2 = E(text2)

print(seq_embd_1[0])
print(seq_embd_2[0])

tensor([ 1.5181, -0.7648, -0.5946, -0.5700], grad_fn=<SelectBackward0>)
tensor([ 1.5181, -0.7648, -0.5946, -0.5700], grad_fn=<SelectBackward0>)


In [5]:
def operation(X):
    X_global = X.mean(dim = 0)
    X = X_global + X
    return X

f1 = operation( seq_embd_1 )[ idx ]
f2 = operation( seq_embd_2 )[ idx ]
print( f1 )
print( f2 )

tensor([ 2.1499, -0.5838, -0.8918, -0.6797], grad_fn=<SelectBackward0>)
tensor([ 2.2314, -0.6370, -0.8608, -0.9149], grad_fn=<SelectBackward0>)


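As shown above, once context is mixed in, the features at position $t=0$ differ between the two sequences.

From the context-level point of view, `X = X_global + X`, where `X_global` is a representation of the whole sequence.

The meaning of a word is shaped by its context.

The goal of sequence modeling is to find an efficient way of combining global information, so that each token's representation reflects its meaning in context.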

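## Efficient Sequence Modeling

In the text “小冬瓜有两把刷子”, the meaning of `刷子` emerges from its grammatical relations to the other words.

- Grammar rules: we could tag roles such as subject, predicate, object, and attributive, then analyze how the words relate to each other (or to the sentence as a whole) to pin down the meaning of `刷子`.
- Automated rules: hand-written grammar rules are complex (postposed attributives and inverted sentences, for example), and Chinese and English grammar differ. We would rather represent a token's meaning in context **automatically**.

Given how complex language patterns are, a general-purpose language representation amounts to automatically characterizing each token's representation in its context.

We define weights to describe the grammatical relations between tokens.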

|                | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $\overline{w}$ | 1/8  | 1/8  | 1/8  | 1/8  | 1/8  | 1/8  | 1/8  | 1/8  |
| $w$            | 0.1  | 0.1  | 0.5  | 0.01 | 0.01 | 0.08 | 0.1  | 0.1  |
| $w_刷$         | 0.0  | 0.1  | 0.1  | 0.4  | 0.0  | 0.0  | 0.3  | 0.1  |

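- $\overline{w}$: every token is equally important to the global feature.
- $w$: tokens differ in their importance to the global feature.
- $w_刷$: the view from the token `刷`, i.e. how it relates to each of the other tokens. This is the fine-grained version: a relation pattern specific to this one token. Every other token likewise gets its own weight row, and together these rows act as the automated grammar rules.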
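As a minimal sketch of the three weight rows above, the snippet below forms the weighted sum $\sum_j w_j X_j$ for each row. The feature matrix `X` is random and purely illustrative (an assumption, not data from the text); the weights are copied from the table.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 8, 4
X = torch.randn(seq_len, dim)  # illustrative token features, one row per token

w_uniform = torch.full((seq_len,), 1 / seq_len)                        # \overline{w}
w_global = torch.tensor([0.1, 0.1, 0.5, 0.01, 0.01, 0.08, 0.1, 0.1])   # w
w_shua = torch.tensor([0.0, 0.1, 0.1, 0.4, 0.0, 0.0, 0.3, 0.1])        # w_刷

for name, w in [("uniform", w_uniform), ("global", w_global), ("shua", w_shua)]:
    S = w @ X  # weighted sum over tokens -> [dim]
    print(name, float(w.sum()), S)

# With uniform weights the weighted sum is plain mean pooling
S_uniform = w_uniform @ X
print(torch.allclose(S_uniform, X.mean(dim=0), atol=1e-6))
```

Only the third row depends on which token is asking, which is the property the rest of this section builds on.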

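Likewise, in a long novel perhaps only a few key plot points drive the story forward: the story is strongly tied to those plot points and only weakly, or not at all, to the rest of the description.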


|                | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $w_刷$         | 0.0  | 0.0  | 0.0  | 0.0  | 0.2  | 0.4  | 0.3  | 0.1  |

If we end up with the wrong weights, `刷子` can hardly be given the right meaning in its context.

$$
S_i = \sum_j w_{ij} X_j
$$

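So the real difficulty is how to model the weights $w_{ij}$ of these "automated grammar rules", and thereby improve semantic understanding.

Below we use the inner product of two feature vectors to measure how related they are. (The inner product is just one possible measure of relatedness.)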

$$
w_{ij} = X_i X_j^T
$$

In [6]:
X = seq_embd_1
i = 0
X_0 = seq_embd_1[i,:].unsqueeze(dim = 0)
print(X.shape)
print(X_0.shape)

# Loop implementation
X_weight_i = torch.zeros(1, dim)
w_ij = torch.zeros(seq_len)
for j in range(seq_len):
    w_ij[j] = X[i,:] @ X[j,:].t()
    X_weight_i += w_ij[j] * X[j,:] # weight * feature
print(w_ij/w_ij.max())
print(X_weight_i)


# Matrix implementation
s = X_0 @ X.t()
print(s.shape)
X_weight_i = s @ X
print(X_weight_i)

torch.Size([8, 4])
torch.Size([1, 4])
tensor([ 1.0000, -0.5625, -0.2787,  0.9354,  0.2386,  0.7529,  0.7529, -0.4618],
       grad_fn=<DivBackward0>)
tensor([[ 19.6897, -12.4808, -14.2707,   2.0402]], grad_fn=<AddBackward0>)
torch.Size([1, 8])
tensor([[ 19.6897, -12.4808, -14.2707,   2.0402]], grad_fn=<MmBackward0>)
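The same computation extends to every query position at once. The sketch below assumes only that `X` is a `[seq_len, dim]` feature matrix (a random one is used here) and cross-checks one row against the explicit loop.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 8, 4
X = torch.randn(seq_len, dim)

# Pairwise inner products: W[i, j] = X_i X_j^T
W = X @ X.t()      # [seq_len, seq_len]
S_all = W @ X      # row i is S_i = sum_j W[i, j] * X_j

# Cross-check row 0 against the explicit loop
S_0 = torch.zeros(dim)
for j in range(seq_len):
    S_0 += (X[0] @ X[j]) * X[j]

print(torch.allclose(S_all[0], S_0, atol=1e-4))
print(torch.allclose(W, W.t(), atol=1e-6))  # raw inner-product weights are symmetric
```

Note that without any projection the weights satisfy $w_{ij} = w_{ji}$, so token $i$ always relates to token $j$ exactly as strongly as $j$ relates to $i$.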


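## The Attention Mechanism

The steps above give us one candidate for the transformation $\mathcal{F}(\cdot):\mathbb{R}^d \rightarrow \mathbb{R}^d$.

In essence, we are looking for the representation of a word within its sequence:

\begin{align}
S_i &= \sum_j w_{ij} X_j,\\
w_{ij} &= X_i X_j^T
\end{align}

Expanding the sum:

\begin{align}
S_i = w_{i1} X_1 + w_{i2} X_2 + \ldots + w_{iN} X_N
\end{align}

There are two parts: the weight $w_{ij}$ measures how much each token contributes, and $X_j$ is the ordinary feature representation.

Expanding further:

\begin{align}
S_i = (X_i X_1^T)\cdot X_1 + (X_i X_2^T)\cdot X_2 + \ldots + (X_i X_N^T)\cdot X_N, \quad X_i X_j^T\in\mathbb{R}
\end{align}

Each $X_i X_j^T$ is a scalar. We now apply a linear feature transformation to every factor: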


| $(X_i$    | $X_j^T)$      | $\cdot$ | $X_j$    |
| --------- | ------------- | ------- | -------- |
| $(X_iW_Q$ | $(X_jW_K)^T)$ | $\cdot$ | $X_jW_V$ |
| $(Q_i$    | $K^T_j)$      | $\cdot$ | $V_j$    |
| query     | key           | $\cdot$ | value    |

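- query: the query vector. Token $i$ wants to know its own representation in this context, so the querying token $i$ carries a door key, $Q_i$.
- key: the key vector. Every token in the context is a door $K_j$; to measure how related two tokens are, the querying token takes $Q_i$ and **visits** each door $K_j$.
- value: the value vector, i.e. the content $V_j$ behind the door.

In other words, a query token $Q_i$ knocks on every key door $K_j$, reads the content $V_j$ behind it, and only after combining all of that information obtains its **fused** representation.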

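\begin{align}
S_i &= \sum_j (Q_iK^T_j) V_j,\\
Q_i &= X_iW_Q, \\
K_j &= X_jW_K, \\
V_j &= X_jW_V
\end{align}

Here $W_Q,W_K,W_V \in \mathbb{R}^{d \times d}$ project each token's features so that they can play the roles of query, key, and value. The expression above is the **attention mechanism**, and its output is called the **attention feature**.

The original form, **a weighted $(w_{ij})$ combination $(\sum)$ of features $(X_j)$**, is therefore an **attention mechanism** as well:

\begin{align}
S_i = \sum_j w_{ij} X_j
\end{align}

**If the raw input vectors can already be used to compute attention, why do we need projections?**

Take $w_{ij}$ to be the row $w_{刷}$: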

|                | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $w_刷$         | 0.0  | 0.0  | 0.0  | 0.0  | 0.2  | 0.4  | 0.3  | 0.1  |

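Here $w_{\text{刷},\text{瓜}} = Q_{\text{刷}}K_{\text{瓜}}^T = 0.0$. This weighting ignores the subject “小冬瓜”, so `刷` is given the wrong meaning and the task fails.

Backpropagation updates the projections to $W_Q', W_K'$, which changes the weight to, say, $w_{\text{刷},\text{瓜}} = Q_{\text{刷}}'{K_{\text{瓜}}'}^T = 0.2$ and thereby changes the attention relation. $W_V$ is updated in the same way.

Introducing parameters to project the representations is meant to **let a query attend efficiently to the tokens that contribute most to the prediction task**.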
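To make the role of the projections concrete, here is a small sketch in which gradient steps on $W_Q$ and $W_K$ raise one chosen attention weight. Everything in it is an illustrative assumption: the random features `X`, the positions `i` and `j`, the softmax normalization of the scores, the learning rate, and the stand-in objective $-\log w_{ij}$ (a real task would supply its own loss).

```python
import torch

torch.manual_seed(0)
seq_len, dim = 8, 4
X = torch.randn(seq_len, dim)  # illustrative token features
W_Q = (0.1 * torch.randn(dim, dim)).requires_grad_()
W_K = (0.1 * torch.randn(dim, dim)).requires_grad_()

i, j = 6, 2  # hypothetical positions: query token i should attend more to token j

def weight_ij():
    Q = X @ W_Q
    K = X @ W_K
    P = torch.softmax(Q[i] @ K.t(), dim=-1)  # attention weights of query i
    return P[j]

before = weight_ij().item()
opt = torch.optim.SGD([W_Q, W_K], lr=0.05)
for _ in range(200):
    opt.zero_grad()
    loss = -torch.log(weight_ij())  # stand-in objective: make w_ij large
    loss.backward()
    opt.step()
after = weight_ij().item()
print(before, after)
```

The features `X` never change in this sketch; only the projections do, and that alone is enough to move the weight.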



In [7]:
class AttentionSimplest(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.WQ = nn.Linear(dim_in, dim_out)
        self.WK = nn.Linear(dim_in, dim_out)
        self.WV = nn.Linear(dim_in, dim_out)
        # W_O projects the attention output back to the external feature space
        self.WO = nn.Linear(dim_out, dim_out)
        
    def forward(self, X, i):
        seq_len, dim = X.shape
        query_i = X[i, :].unsqueeze(dim = 0)
        qi = self.WQ(query_i)
        K = self.WK(X)
        V = self.WV(X)

        # Accumulate sum_j w_ij * V_j for the single query i
        attn_i = torch.zeros(1, V.shape[-1])
        for j in range(seq_len):
            w_ij = qi @ K[j,:].t()
            attn_i += w_ij * V[j,:]

        output = self.WO(attn_i)
        return output

    def forward_basic(self, X, i):
        """
        This version skips the projections and still computes attention.
        """
        seq_len, dim = X.shape
        qi = X[i, :].unsqueeze(dim = 0)
        K = X
        V = X
        attn_i = torch.zeros(1, dim)
        for j in range(seq_len):
            w_ij = qi @ K[j,:].t()
            attn_i += w_ij * V[j,:]
        return attn_i

X = torch.randn(seq_len, dim)
attn = AttentionSimplest(dim, dim)


# For every query i we obtain the contextual representation of token_i
O_0 = attn(X, 0)
O_1 = attn(X, 1)
print(O_0)
print(O_1)


# The same queries without any projection
O_0 = attn.forward_basic(X, 0)
O_1 = attn.forward_basic(X, 1)
print(O_0)
print(O_1)

# These outputs carry no grad_fn, since no parameters are involved.

tensor([[-0.0310,  0.4020, -0.9460, -1.0811]], grad_fn=<AddmmBackward0>)
tensor([[-3.0453e-02, -2.2811e-03,  7.2466e-01,  2.6776e+00]],
       grad_fn=<AddmmBackward0>)
tensor([[ 13.9701, -18.9218,  24.5574,  30.4336]])
tensor([[ -1.8335,  16.1551, -15.2250, -24.2852]])


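## Normalizing Attention Scores

The raw scores $Q_iK_j^T$ are unbounded and can be negative. Softmax turns them into nonnegative weights that sum to 1.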

In [8]:
Q_i = torch.randn(1, dim)
K = torch.randn(seq_len, dim)
V = torch.randn(seq_len, dim)

S = Q_i @ K.t()
print(S)

# Normalize with softmax
P = F.softmax(S, dim = -1)
print(P)
print(P.sum())

O = P @ V

# Hand-written softmax; subtracting the max keeps exp() from overflowing
def softmax(X):
    m = torch.max(X)
    X_exp = torch.exp(X - m)
    L = torch.sum(X_exp)
    P = X_exp / L
    return P

P = softmax(S[0,:])
print(P)
print(P.sum())

tensor([[ 0.3505, -1.6537,  2.7238,  0.4472,  1.1096, -0.1954, -2.0742,  0.7644]])
tensor([[0.0578, 0.0078, 0.6209, 0.0637, 0.1236, 0.0335, 0.0051, 0.0875]])
tensor(1.)
tensor([0.0578, 0.0078, 0.6209, 0.0637, 0.1236, 0.0335, 0.0051, 0.0875])
tensor(1.)
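The hand-written `softmax` above subtracts the maximum before exponentiating. A short sketch of why, using deliberately large scores (the values are illustrative): in float32, `exp` overflows to `inf` well before 1000, so the naive form produces NaN while the shifted form does not.

```python
import torch

def softmax_naive(x):
    e = torch.exp(x)            # overflows to inf for large scores
    return e / e.sum()

def softmax_stable(x):
    e = torch.exp(x - x.max())  # largest exponent is exp(0) = 1
    return e / e.sum()

scores = torch.tensor([1000.0, 999.0, 998.0])  # illustrative large scores
p_naive = softmax_naive(scores)
p_stable = softmax_stable(scores)
print(p_naive)
print(p_stable)
print(torch.softmax(scores, dim=-1))
```

Shifting every score by the same constant leaves the softmax result unchanged, so the stable form is the same function, just computed safely.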


## Masked Attention

In the earlier example we had these attention weights:

|                | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $w_刷$         | 0.0  | 0.1  | 0.1  | 0.4  | 0.0  | 0.0  | 0.3  | 0.1  |

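Suppose we impose a rule by hand: the key doors at odd positions (counting from 1, i.e. the 0-based indices $j$ with $j \bmod 2 = 0$) may not be visited, so those doors $K_j$ are locked. In the mask, 1 marks an open door and 0 a locked one.

|                  | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| ---------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $w_刷$           | 0.0  | 0.1  | 0.1  | 0.4  | 0.0  | 0.0  | 0.3  | 0.1  |
| $\text{mask}_刷$ | 0    | 1    | 0    | 1    | 0    | 1    | 0    | 1    |

If a key door cannot be opened, the content (value) behind it cannot be read either. Likewise, every token $i$ has its own lock $\text{mask}_i$.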


\begin{align}
S_i &= \sum_j \textcolor{red}{\text{mask}_{ij}}\, w_{ij} X_j,\\
w_{ij} &= X_i X_j^T
\end{align}

In [9]:
mask = torch.zeros(1,seq_len)
idx = torch.arange(1, seq_len, 2)
print(idx)
mask[0, idx] = 1
print(mask)

tensor([1, 3, 5, 7])
tensor([[0., 1., 0., 1., 0., 1., 0., 1.]])


In [10]:
# Masked attention
Q_i = torch.randn(1, dim)
K = torch.randn(seq_len, dim)
V = torch.randn(seq_len, dim)

S = Q_i @ K.t()
print(S)

# Apply the mask by multiplication
S_mask = S * mask
print(S_mask)

# Normalize with softmax
P = F.softmax(S_mask, dim = -1)
print(P) # Problem: masked tokens still receive weight, because exp(0) = 1
print(P.sum())

tensor([[-0.0627,  0.9994,  0.7831,  0.2163, -2.1494,  1.6317,  1.8986,  1.0182]])
tensor([[-0.0000,  0.9994,  0.0000,  0.2163, -0.0000,  1.6317,  0.0000,  1.0182]])
tensor([[0.0631, 0.1715, 0.0631, 0.0784, 0.0631, 0.3228, 0.0631, 0.1748]])
tensor(1.)


In [11]:
idx = torch.arange(0, seq_len, 2)
print(idx)

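# Apply the mask by overwriting the masked scores
S_inf_mask = S.clone()
S_inf_mask[0,idx] = -10000.0 # a large negative value standing in for negative infinity
print(S_inf_mask)

# Normalize with softmax
P = F.softmax(S_inf_mask, dim = -1)
print(P) # OK: masked tokens now get zero weight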
print(P.sum())

tensor([0, 2, 4, 6])
tensor([[-1.0000e+04,  9.9941e-01, -1.0000e+04,  2.1627e-01, -1.0000e+04,
          1.6317e+00, -1.0000e+04,  1.0182e+00]])
tensor([[0.0000, 0.2295, 0.0000, 0.1049, 0.0000, 0.4318, 0.0000, 0.2338]])
tensor(1.)
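The constant `-10000.0` stands in for $-\infty$. The sketch below compares the two choices using `Tensor.masked_fill`; the boolean tensor `keep` (True means the key may be attended to) and the random scores are illustrative assumptions. The results agree whenever at least one key stays open, and differ when a row is fully masked.

```python
import torch

torch.manual_seed(0)
S = torch.randn(1, 8)
keep = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.bool)  # True = open door

P_big = torch.softmax(S.masked_fill(~keep, -10000.0), dim=-1)
P_inf = torch.softmax(S.masked_fill(~keep, float("-inf")), dim=-1)
print(torch.allclose(P_big, P_inf))

# Edge case: a row in which every key is masked
keep_none = torch.zeros(1, 8, dtype=torch.bool)
row_big = torch.softmax(S.masked_fill(~keep_none, -10000.0), dim=-1)
row_inf = torch.softmax(S.masked_fill(~keep_none, float("-inf")), dim=-1)
print(row_big)  # uniform weights
print(row_inf)  # NaN
```

A finite constant degrades gracefully to uniform weights on a fully masked row, which is one practical reason to prefer it when such rows can occur (for example, rows that belong to padding tokens).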


## Scaled Dot-Product Attention (ScaledDotProductAttention)

In [12]:
mask = torch.tril(torch.ones(1,5,5))
print(mask)
torch.where(mask==0)

tensor([[[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1.]]])


(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3]),
 tensor([1, 2, 3, 4, 2, 3, 4, 3, 4, 4]))

In [13]:
Q = torch.randn(2, 5, dim) # (batch_size = 2, seq_len = 5, dim)
K = torch.randn(2, 5, dim)
V = torch.randn(2, 5, dim)

S = Q[0, -1, :] @ K[0, :, :].transpose(0, 1) # first sample in the batch, last token as the query
print(S)

S = Q[0, :, :] @ K[0, :, :].transpose(0, 1) # attention scores for the whole sequence of the first sample
print(S)

S = Q @ K.transpose(1, 2) # attention scores for the whole batch in parallel
print(S)

tensor([ 0.0664,  1.7814,  2.5556, -0.5928, -0.5491])
tensor([[ 1.9017,  0.1047,  3.4827, -2.1984, -0.7388],
        [-0.8771,  0.1031, -1.4690, -0.4169,  0.3927],
        [ 2.2289, -1.5629, -4.0477, -0.0432,  2.9467],
        [-0.4769,  0.6820, -0.7752, -0.0132,  0.5000],
        [ 0.0664,  1.7814,  2.5556, -0.5928, -0.5491]])
tensor([[[ 1.9017,  0.1047,  3.4827, -2.1984, -0.7388],
         [-0.8771,  0.1031, -1.4690, -0.4169,  0.3927],
         [ 2.2289, -1.5629, -4.0477, -0.0432,  2.9467],
         [-0.4769,  0.6820, -0.7752, -0.0132,  0.5000],
         [ 0.0664,  1.7814,  2.5556, -0.5928, -0.5491]],

        [[ 4.3100,  2.1476, -2.4445, -0.3642, -0.8997],
         [ 2.0166,  0.8142, -0.0494,  0.3950,  0.1620],
         [-2.8694, -4.4273, -0.4835, -1.7465, -1.0318],
         [-1.9456, -2.2504, -0.9859, -1.3685, -0.8618],
         [ 3.3065,  5.8300, -1.3517,  1.0775,  0.5839]]])


In [14]:
## Full attention
import math

class ScaleDotProductAttention(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.WQ = nn.Linear(dim_in, dim_out)
        self.WK = nn.Linear(dim_in, dim_out)
        self.WV = nn.Linear(dim_in, dim_out)
        self.WO = nn.Linear(dim_out, dim_out)
        
    def forward(self, X, mask = None):
        batch_size, seq_len, dim = X.shape
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)

        # Scores for all queries q_i at once: [batch, seq_len, seq_len]
        S = Q @ K.transpose(1,2) / math.sqrt(Q.shape[-1]) # 1. Why divide by \sqrt{d}?

        if mask is not None:
            # mask: [batch, seq_len, seq_len]; 0 marks a key that must not be attended to
            idx = torch.where(mask==0)
            S[idx[0],idx[1],idx[2]] = -10000.0
        
        P = torch.softmax(S, dim = -1) # row-wise softmax
        Z = P @ V
        output = self.WO(Z)
        
        return output

X = torch.randn(64, seq_len, dim)
# Causal alternative: mask = torch.tril(torch.ones(64, seq_len, seq_len))
mask = torch.ones(64, seq_len, seq_len)
model = ScaleDotProductAttention(dim, dim)
Y = model(X, mask)
print(Y.shape)

torch.Size([64, 8, 4])
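The code comment above asks why the scores are divided by $\sqrt{d}$. One common argument, sketched here under the assumption that the entries of `q` and `k` are independent with zero mean and unit variance: the dot product of two such $d$-dimensional vectors has variance $d$, so unscaled scores grow with the dimension and push softmax toward saturation. Dividing by $\sqrt{d}$ brings the variance back to about 1.

```python
import math
import torch

torch.manual_seed(0)
n = 10000  # number of sampled (q, k) pairs per dimension
results = {}
for d in (4, 64, 512):
    q = torch.randn(n, d)
    k = torch.randn(n, d)
    s = (q * k).sum(dim=-1)  # n independent dot products
    results[d] = (s.var().item(), (s / math.sqrt(d)).var().item())
    print(d, results[d])     # (variance of raw scores, variance after scaling)
```

The first number in each pair should be close to `d` and the second close to 1; the exact values depend on the random sample.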


In [15]:
## Example: text classification with attention
import torch.optim as optim

# Data
bs = 32
seq_len = 100
dim = 512
vocab_size = 26
class_num = 3 # sentiment classes: negative, neutral, positive


class AttentionClassifyModel(nn.Module):
    def __init__(self, dim = 512, vocab_size = 100, class_num = 2):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.class_num = class_num
        self.E = nn.Embedding(vocab_size, dim)
        self.attention = ScaleDotProductAttention(dim, dim)
        self.head = nn.Linear(dim, class_num)
        
    def forward(self, X, mask):
        bs,seq_len = X.shape
        X = self.E(X)
        X = self.attention(X, mask)
        h = X.mean(dim = 1) # mean pooling over the sequence dimension
        Y = self.head(h)
        return Y # logits

input_ids = torch.randint(0, vocab_size, size=(bs, seq_len)) # corpus (random token ids)
Y = torch.randint(0, class_num, size=(1, bs))[0]

model = AttentionClassifyModel(dim, vocab_size, class_num)
optimizer = optim.SGD(model.parameters(), lr = 1e-2)
loss_fn = nn.CrossEntropyLoss()
print(model)

AttentionClassifyModel(
  (E): Embedding(26, 512)
  (attention): ScaleDotProductAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
  )
  (head): Linear(in_features=512, out_features=3, bias=True)
)


In [16]:
for i in range(1000):
    optimizer.zero_grad()
    Y_pred = model(input_ids, None)

    loss = loss_fn(Y_pred, Y)   
    if i % 100 == 0:
        print(loss)
    loss.backward()
    optimizer.step()

tensor(1.0939, grad_fn=<NllLossBackward0>)
tensor(1.0038, grad_fn=<NllLossBackward0>)
tensor(0.9579, grad_fn=<NllLossBackward0>)
tensor(0.8850, grad_fn=<NllLossBackward0>)
tensor(0.7633, grad_fn=<NllLossBackward0>)
tensor(0.5938, grad_fn=<NllLossBackward0>)
tensor(0.4109, grad_fn=<NllLossBackward0>)
tensor(0.2699, grad_fn=<NllLossBackward0>)
tensor(0.1797, grad_fn=<NllLossBackward0>)
tensor(0.1213, grad_fn=<NllLossBackward0>)


## Example 2: Next-Token Prediction with Attention

In [17]:
input_ids = torch.randint(0, vocab_size, size=(bs, seq_len)) # corpus (random token ids)
print(input_ids)
# Shift left by one along the sequence dimension to get the labels.
# Without dims, torch.roll flattens the tensor first, so labels would leak across rows.
torch.roll(input_ids, shifts = -1, dims = 1)

tensor([[18,  8,  7,  ...,  7, 12,  9],
        [21, 12,  2,  ..., 10,  2, 24],
        [ 0, 23, 17,  ..., 20, 20,  1],
        ...,
        [ 2,  1,  3,  ...,  7,  2, 22],
        [ 9, 23, 15,  ..., 24, 16, 13],
        [16, 16, 21,  ...,  4, 11,  6]])


tensor([[ 8,  7, 11,  ..., 12,  9, 18],
        [12,  2, 22,  ...,  2, 24, 21],
        [23, 17,  4,  ..., 20,  1,  0],
        ...,
        [ 1,  3, 20,  ...,  2, 22,  2],
        [23, 15, 22,  ..., 16, 13,  9],
        [16, 21, 23,  ..., 11,  6, 16]])

In [18]:
## Example 2: next-token prediction with attention
import torch.optim as optim

# Data
bs = 32
seq_len = 100
dim = 512
vocab_size = 26


class AttentionLanguageModel(nn.Module):
    def __init__(self, dim = 512, vocab_size = 100):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.E = nn.Embedding(vocab_size, dim)
        self.attention = ScaleDotProductAttention(dim, dim)
        self.head = nn.Linear(dim, vocab_size)
        
    def forward(self, X, mask):
        bs,seq_len = X.shape
        X = self.E(X)
        X = self.attention(X, mask)
        # No pooling here: keep one output per position
        Y = self.head(X)
        return Y # logits: [bs, seq_len, vocab_size]
        
input_ids = torch.randint(0, vocab_size, size=(bs, seq_len)) # corpus (random token ids)
Y = torch.roll(input_ids.clone(), shifts = -1, dims = 1) # next-token labels, shifted within each row

model = AttentionLanguageModel(dim, vocab_size)
optimizer = optim.SGD(model.parameters(), lr = 1e-2)
loss_fn = nn.CrossEntropyLoss()
print(model)

AttentionLanguageModel(
  (E): Embedding(26, 512)
  (attention): ScaleDotProductAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
  )
  (head): Linear(in_features=512, out_features=26, bias=True)
)


In [19]:
mask = torch.tril(torch.ones(bs, seq_len, seq_len))
Y_pred = model(input_ids, mask)
print(input_ids.shape)
print(Y_pred.shape)

torch.Size([32, 100])
torch.Size([32, 100, 26])
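The lower-triangular mask is what makes next-token prediction meaningful: position $t$ may only look at positions up to $t$. The sketch below checks this on a stripped-down, projection-free attention function (`attend` is a hypothetical helper written for this check, not the model above): changing the tokens after position `t` leaves the outputs up to `t` untouched.

```python
import math
import torch

torch.manual_seed(0)
seq_len, dim = 6, 4
X = torch.randn(1, seq_len, dim)
causal = torch.tril(torch.ones(1, seq_len, seq_len))

def attend(X, mask):
    # Projection-free attention, enough to study the effect of the mask
    S = X @ X.transpose(1, 2) / math.sqrt(X.shape[-1])
    S = S.masked_fill(mask == 0, -10000.0)
    return torch.softmax(S, dim=-1) @ X

t = 3
X_future = X.clone()
X_future[:, t + 1:, :] = torch.randn(1, seq_len - t - 1, dim)  # change only tokens after t

out_a = attend(X, causal)
out_b = attend(X_future, causal)
same_past = torch.allclose(out_a[:, :t + 1], out_b[:, :t + 1])
same_future = torch.allclose(out_a[:, t + 1:], out_b[:, t + 1:])
print(same_past, same_future)
```

With an all-ones mask the first check would fail, because every position could then read the tokens it is supposed to predict.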


In [20]:
for i in range(1000):
    optimizer.zero_grad()
    Y_pred = model(input_ids, mask)

    loss = loss_fn(Y_pred.view(bs*seq_len, vocab_size), 
                   Y.view(bs*seq_len), ) 
    
    if i % 100 == 0:
        print(loss)
    loss.backward()
    optimizer.step()

tensor(3.2611, grad_fn=<NllLossBackward0>)
tensor(3.2556, grad_fn=<NllLossBackward0>)
tensor(3.2516, grad_fn=<NllLossBackward0>)
tensor(3.2482, grad_fn=<NllLossBackward0>)
tensor(3.2452, grad_fn=<NllLossBackward0>)
tensor(3.2423, grad_fn=<NllLossBackward0>)
tensor(3.2394, grad_fn=<NllLossBackward0>)
tensor(3.2364, grad_fn=<NllLossBackward0>)
tensor(3.2332, grad_fn=<NllLossBackward0>)
tensor(3.2298, grad_fn=<NllLossBackward0>)
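A useful reference point for the losses above: a model that predicts a uniform distribution over the vocabulary has cross-entropy $\ln(\text{vocab\_size})$. Because the corpus here consists of random token ids, there is little structure to learn, which is one plausible reason the loss stays close to that value. A quick check of the baseline:

```python
import math
import torch
import torch.nn as nn

vocab_size = 26
logits = torch.zeros(5, vocab_size)          # uniform prediction for 5 positions
labels = torch.randint(0, vocab_size, (5,))  # any labels give the same loss here
uniform_loss = nn.CrossEntropyLoss()(logits, labels).item()
print(uniform_loss, math.log(vocab_size))
```

Both printed numbers are $\ln 26 \approx 3.258$, up to float32 rounding.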


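## Multi-Head Attention

Attention weights can be sparse: in a long sequence (say `seq_len` = 1024), only a handful of tokens may receive significant weight.

A single weight distribution can therefore miss information.

How can we capture more diverse attention relations? Different sets of weights could specialize, for example:

- attention weights that focus on what matters for the classification task
- attention weights that focus on syntactic analysis

The example below widens the reach of **single-head** attention by giving `刷` a second weight row: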

|                | 小   | 冬   | 瓜   | 有   | 两   | 把   | 刷   | 子   |
| -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| $w^{(1)}_刷$   | 0.0  | 0.0  | 0.0  | 0.0  | 0.2  | 0.4  | 0.3  | 0.1  |
| $w^{(2)}_刷$   | 0.1  | 0.05 | 0.45 | 0.1  | 0.2  | 0.0  | 0.1  | 0.0  |


\begin{align}
S^{(h)}_i &= \sum_j w^{(h)}_{ij} V_j^{(h)},\\
w^{(h)}_{ij} &= Q^{(h)}_i {K^{(h)}_j}^T
\end{align}

where $\texttt{Split}(Q) \rightarrow Q^{(1)},Q^{(2)},\ldots,Q^{(H)} \in \mathbb{R}^{N \times d'}$ with $d' = d / H$, and likewise for $K$ and $V$.

In [21]:
bs = 2
seq_len = 5
dim = 512
heads = 8
head_dim = dim // heads

Q = torch.randn(bs, seq_len, dim)
K = torch.randn(bs, seq_len, dim)
V = torch.randn(bs, seq_len, dim)

In [22]:
# single head

S = Q @ K.transpose(1,2) / math.sqrt(dim)
P = torch.softmax(S, dim = -1)
Z = P @ V

print(S.shape)
print(Z.shape)

torch.Size([2, 5, 5])
torch.Size([2, 5, 512])


In [23]:
# Multi-head split; first try a single head
print(Q.shape)
Q_h = Q.view(bs, seq_len, heads, head_dim)
K_h = K.view(bs, seq_len, heads, head_dim)
V_h = V.view(bs, seq_len, heads, head_dim)
print(head_dim)
print(Q_h.shape)


Q_0 = Q_h[:, :, 0, :] # head 0
K_0 = K_h[:, :, 0, :] # head 0
V_0 = V_h[:, :, 0, :] # head 0
print(Q_0.shape)

S = Q_0 @ K_0.transpose(1,2) / math.sqrt( head_dim ) # scale by the per-head dimension
P = torch.softmax(S, dim = -1)
Z_0 = P @ V_0
print(S.shape)
print(Z_0.shape)

torch.Size([2, 5, 512])
64
torch.Size([2, 5, 8, 64])
torch.Size([2, 5, 64])
torch.Size([2, 5, 5])
torch.Size([2, 5, 64])


In [24]:
# Multi-head, all heads in parallel
print(Q_h.shape)
print(Q_h.transpose(1,2).shape) # matmul acts on the last two dims, which must be (seq, head_dim), so move the head axis forward
print(K_h.transpose(1,2).transpose(2,3).shape)

Q_h = Q_h.transpose(1,2)
K_h = K_h.transpose(1,2)
V_h = V_h.transpose(1,2)

S = Q_h @ K_h.transpose(2,3) / math.sqrt( head_dim ) # scale by the per-head dimension
P = torch.softmax(S, dim = -1)
Z = P @ V_h
print(S.shape)
print(Z.shape)


Z = Z.transpose(1,2).reshape(bs, seq_len, dim)
print(Z.shape)

torch.Size([2, 5, 8, 64])
torch.Size([2, 8, 5, 64])
torch.Size([2, 8, 64, 5])
torch.Size([2, 8, 5, 5])
torch.Size([2, 8, 5, 64])
torch.Size([2, 5, 512])
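As a sanity check on the reshaping, the sketch below confirms that the parallel computation really does treat each head as an independent slice of the feature axis. The tensors are random and the head index `h = 3` is an arbitrary choice for illustration.

```python
import math
import torch

torch.manual_seed(0)
bs, seq_len, dim, heads = 2, 5, 512, 8
head_dim = dim // heads
Q = torch.randn(bs, seq_len, dim)
K = torch.randn(bs, seq_len, dim)
V = torch.randn(bs, seq_len, dim)

def split(T):
    # [bs, seq_len, dim] -> [bs, heads, seq_len, head_dim]
    return T.view(bs, seq_len, heads, head_dim).transpose(1, 2)

Q_h, K_h, V_h = split(Q), split(K), split(V)
P = torch.softmax(Q_h @ K_h.transpose(2, 3) / math.sqrt(head_dim), dim=-1)
Z = (P @ V_h).transpose(1, 2).reshape(bs, seq_len, dim)

# Head h alone, computed directly from its slice of the feature axis
h = 3
sl = slice(h * head_dim, (h + 1) * head_dim)
S_h = Q[:, :, sl] @ K[:, :, sl].transpose(1, 2) / math.sqrt(head_dim)
Z_h = torch.softmax(S_h, dim=-1) @ V[:, :, sl]

print(torch.allclose(Z[:, :, sl], Z_h, atol=1e-4))
```

Head `h` owns features `h * head_dim` through `(h + 1) * head_dim - 1` both before the split and after the merge, which is why a plain `view` and `reshape` are sufficient.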


In [25]:
## Multi-head attention implementation
import math

class MultiHeadScaleDotProductAttention(nn.Module):
    def __init__(self, dim_in, dim_out, heads = 8):
        super().__init__()
        self.WQ = nn.Linear(dim_in, dim_out)
        self.WK = nn.Linear(dim_in, dim_out)
        self.WV = nn.Linear(dim_in, dim_out)
        self.WO = nn.Linear(dim_out, dim_out)
        self.heads = heads
        self.head_dim = dim_out // self.heads
        
    def forward(self, X, mask = None):
        batch_size, seq_len, dim = X.shape
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)

        # Split into heads: [batch, heads, seq_len, head_dim]
        Q_h = Q.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1,2)
        K_h = K.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1,2)
        V_h = V.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1,2)

        # Scores for all queries and heads: [batch, heads, seq_len, seq_len]
        S = Q_h @ K_h.transpose(2,3) / math.sqrt(self.head_dim) # 1. Why divide by \sqrt{d}?

        # mask is [batch, seq_len, seq_len]; add a head axis so it broadcasts over heads.
        # Indexing S[:, idx[0], idx[1], idx[2]] with idx = torch.where(mask==0) would be wrong:
        # it applies the batch index to the head axis.
        if mask is not None:
            S = S.masked_fill(mask.unsqueeze(1) == 0, -10000.0)
        
        P = torch.softmax(S, dim = -1) # row-wise softmax
        Z = P @ V_h

        # Merge the heads back: [batch, seq_len, heads * head_dim]
        Z = Z.transpose(1,2).reshape(batch_size, seq_len, self.heads * self.head_dim)
        
        output = self.WO(Z)
        
        return output

X = torch.randn(2, seq_len, dim)
mask = torch.ones(2, seq_len, seq_len)
model = MultiHeadScaleDotProductAttention(dim, dim, 8)
Y = model(X, mask)
print(Y.shape)

torch.Size([2, 5, 512])


In [26]:
input_ids = torch.tensor([[1, 2, 3, 0, 0],
                          [1, 2, 0, 0, 0]], dtype = torch.long) # 0 is pad
bs, seq_len = input_ids.shape
mask = torch.ones(bs, seq_len, seq_len)

# Padding mask: zero out every row (query) and column (key) that belongs to a pad token
for i in range(bs):
    pad_idx =  torch.where(input_ids[i, :]  == 0)[0]
    mask[i, pad_idx, :] = 0
    mask[i, :, pad_idx] = 0
print(mask)

score = torch.randn(bs, seq_len, seq_len)
score * mask # for illustration only; before softmax, fill masked scores with a large negative value instead

tensor([[[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


tensor([[[ 0.3580, -2.1688,  0.4910,  0.0000,  0.0000],
         [ 1.6233,  0.3894, -0.1632, -0.0000,  0.0000],
         [-0.2142, -0.2808, -0.7808,  0.0000,  0.0000],
         [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000]],

        [[ 0.2749, -1.1813, -0.0000, -0.0000, -0.0000],
         [-0.4358,  1.2782, -0.0000,  0.0000,  0.0000],
         [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
         [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000]]])

## Example 3: Next-Token Prediction with Multi-Head Attention

In [27]:
import torch.optim as optim

# Data
bs = 32
seq_len = 100
dim = 512
vocab_size = 26


class MultiHeadAttentionLanguageModel(nn.Module):
    def __init__(self, dim = 512, vocab_size = 100):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.E = nn.Embedding(vocab_size, dim)
        self.attention = MultiHeadScaleDotProductAttention(dim, dim) # the only change from Example 2
        self.head = nn.Linear(dim, vocab_size)
        
    def forward(self, X, mask):
        bs,seq_len = X.shape
        X = self.E(X)
        X = self.attention(X, mask)
        Y = self.head(X)
        return Y # logits

input_ids = torch.randint(0, vocab_size, size=(bs, seq_len)) # corpus (random token ids)
Y = torch.roll(input_ids, shifts = -1, dims = 1) # shift left by one within each row to get the labels
mask = torch.tril(torch.ones(bs, seq_len, seq_len)) # causal mask matching the new batch shape

model = MultiHeadAttentionLanguageModel(dim, vocab_size)
optimizer = optim.SGD(model.parameters(), lr = 1e-2)
loss_fn = nn.CrossEntropyLoss()
print(model)

MultiHeadAttentionLanguageModel(
  (E): Embedding(26, 512)
  (attention): MultiHeadScaleDotProductAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
  )
  (head): Linear(in_features=512, out_features=26, bias=True)
)


In [28]:
for i in range(1000):
    optimizer.zero_grad()
    Y_pred = model(input_ids, mask)

    loss = loss_fn(Y_pred.view(bs*seq_len, vocab_size), 
                   Y.view(bs*seq_len), ) 
    
    if i % 100 == 0:
        print(loss)
    loss.backward()
    optimizer.step()

tensor(3.2587, grad_fn=<NllLossBackward0>)
tensor(3.2563, grad_fn=<NllLossBackward0>)
tensor(3.2544, grad_fn=<NllLossBackward0>)
tensor(3.2528, grad_fn=<NllLossBackward0>)
tensor(3.2514, grad_fn=<NllLossBackward0>)
tensor(3.2501, grad_fn=<NllLossBackward0>)
tensor(3.2488, grad_fn=<NllLossBackward0>)
tensor(3.2474, grad_fn=<NllLossBackward0>)
tensor(3.2459, grad_fn=<NllLossBackward0>)
tensor(3.2444, grad_fn=<NllLossBackward0>)


## MultiHeadAttention Backward(*)

### Single-Head Version: PyTorch Autograd

In [29]:
X = torch.randn(2, 3, requires_grad = True) #  seq_len, dim 
X.retain_grad()

Wq = torch.randn(3, 4, requires_grad = True)
Wk = torch.randn(3, 4, requires_grad = True)
Wv = torch.randn(3, 4, requires_grad = True)
Wo = torch.randn(4, 3, requires_grad = True)

Q = X @ Wq
K = X @ Wk
V = X @ Wv
Q.retain_grad()
K.retain_grad()
V.retain_grad()

S = Q @ K.t() / math.sqrt(4)
S.retain_grad()

P = F.softmax(S, dim = -1)
P.retain_grad()

Z = P @ V
Z.retain_grad()
print(Z.shape)

O = Z @ Wo
O.retain_grad()
print(O.shape)

torch.Size([2, 4])
torch.Size([2, 3])


In [30]:
Y = torch.randn(2, 3)
loss_fn = nn.MSELoss()
loss = loss_fn(O, Y)
print(loss)
loss.backward()

print(X.grad)

tensor(6.1211, grad_fn=<MseLossBackward0>)
tensor([[-2.1932, -2.0092, -3.4629],
        [ 9.5022,  1.1678, -5.1825]])


In [31]:
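## Single-head version: manual backward

# MSELoss averages over all elements, so dL/dO = 2 * (O - Y) / O.numel()
dO = (1/O.numel()) * 2 * (O - Y)
print(Wo.shape, dO.shape, Z.shape)
# O = Z @ Wo

dWo = Z.t() @ dO # same shape as Wo: [4, 3]
dZ = dO @ Wo.t()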

torch.Size([4, 3]) torch.Size([2, 3]) torch.Size([2, 4])


In [32]:
# Z = P @ V
dP = dZ @ V.t()
dV = P.t() @ dZ

In [33]:
dS = torch.zeros_like(dP)

# Softmax backward, row by row: the Jacobian of row i is diag(P_i) - P_i P_i^T
for i in range(2):
    dP_dS_i = torch.diag(P[i,:]) - torch.outer(P[i,:] , P[i,:])
    dS[i,:] = dP[i,:] @ dP_dS_i
print(dS.shape)

print(dS)
print(S.grad)

torch.Size([2, 2])
tensor([[ 0.0178, -0.0178],
        [ 3.0729, -3.0729]], grad_fn=<CopySlices>)
tensor([[ 0.0178, -0.0178],
        [ 3.0729, -3.0729]])
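The row-by-row Jacobian product has a vectorized closed form. Since $\partial P_k / \partial S_j = P_k(\delta_{kj} - P_j)$ within a row, the gradient is $dS_j = P_j\,(dP_j - \sum_k dP_k P_k)$. The sketch below checks that form against the loop and against autograd; the tensors are random, and `dP` is an assumed stand-in for the upstream gradient.

```python
import torch

torch.manual_seed(0)
S = torch.randn(2, 2, requires_grad=True)
P = torch.softmax(S, dim=-1)
dP = torch.randn(2, 2)  # stand-in for the upstream gradient dL/dP
Pd = P.detach()

# Row-wise Jacobian form, as in the cell above
dS_loop = torch.zeros_like(dP)
for i in range(S.shape[0]):
    J = torch.diag(Pd[i]) - torch.outer(Pd[i], Pd[i])
    dS_loop[i] = dP[i] @ J

# Vectorized closed form
dS_vec = Pd * (dP - (dP * Pd).sum(dim=-1, keepdim=True))

P.backward(dP)  # autograd reference, fills S.grad
print(torch.allclose(dS_loop, dS_vec, atol=1e-6))
print(torch.allclose(dS_vec, S.grad, atol=1e-6))
```

The closed form avoids building a `[seq_len, seq_len]` Jacobian per row, which matters once the sequence is long.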


In [34]:
# S = Q @ K.t() / sqrt(4)
# Q: 2x4
# K: 2x4, K.t(): 4x2
# S: 2x2

dQ = dS @ K / math.sqrt(4)
dK = dS.t() @ Q /math.sqrt(4)

In [35]:
print(dQ, Q.grad)
print(dK, K.grad)
print(dV, V.grad)

tensor([[ 0.0132, -0.0165, -0.0147,  0.0300],
        [ 2.2869, -2.8550, -2.5510,  5.1993]], grad_fn=<DivBackward0>) tensor([[ 0.0132, -0.0165, -0.0147,  0.0300],
        [ 2.2869, -2.8550, -2.5510,  5.1993]])
tensor([[ 0.5142, -0.0055, -0.3759, -0.1285],
        [-0.5142,  0.0055,  0.3759,  0.1285]], grad_fn=<DivBackward0>) tensor([[ 0.5142, -0.0055, -0.3759, -0.1285],
        [-0.5142,  0.0055,  0.3759,  0.1285]])
tensor([[-1.8593,  0.9742, -1.8910,  1.3738],
        [-2.4568,  1.3461, -0.9550,  0.2883]], grad_fn=<MmBackward0>) tensor([[-1.8593,  0.9742, -1.8910,  1.3738],
        [-2.4568,  1.3461, -0.9550,  0.2883]])


In [36]:
dX_dQ = dQ @ Wq.t()
dX_dK = dK @ Wk.t()
dX_dV = dV @ Wv.t()
dX = dX_dQ + dX_dK + dX_dV # X feeds Q, K and V, so the three gradient paths add up
print(dX)
print(X.grad)

tensor([[-2.1932, -2.0092, -3.4629],
        [ 9.5022,  1.1678, -5.1825]], grad_fn=<AddBackward0>)
tensor([[-2.1932, -2.0092, -3.4629],
        [ 9.5022,  1.1678, -5.1825]])


In [37]:
dWq = X.t() @ dQ
dWk = X.t() @ dK
dWv = X.t() @ dV
print(dWq)
print(Wq.grad)

tensor([[-0.1109,  0.1385,  0.1237, -0.2522],
        [ 1.1945, -1.4913, -1.3325,  2.7158],
        [-0.2443,  0.3049,  0.2725, -0.5553]], grad_fn=<MmBackward0>)
tensor([[-0.1109,  0.1385,  0.1237, -0.2522],
        [ 1.1945, -1.4913, -1.3325,  2.7158],
        [-0.2443,  0.3049,  0.2725, -0.5553]])


### Multi-Head, Multi-Batch Version: Attention Gradients (*)
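One possible sketch of the batched, multi-head case follows. It covers only the core step from per-head $Q, K, V$ to $Z$, and it assumes the same rules as the single-head derivation above, applied to the last two axes. The shapes are illustrative, and `dZ` is a random stand-in for the gradient that would arrive from $W_O$ and the loss. Splitting and merging heads are pure reshapes, so their backward pass is the reverse reshape and is omitted here.

```python
import math
import torch

torch.manual_seed(0)
bs, heads, seq_len, head_dim = 2, 2, 3, 4

# Per-head Q, K, V: [bs, heads, seq_len, head_dim]
Q = torch.randn(bs, heads, seq_len, head_dim, requires_grad=True)
K = torch.randn(bs, heads, seq_len, head_dim, requires_grad=True)
V = torch.randn(bs, heads, seq_len, head_dim, requires_grad=True)

scale = 1.0 / math.sqrt(head_dim)
S = Q @ K.transpose(-1, -2) * scale
P = torch.softmax(S, dim=-1)
Z = P @ V

dZ = torch.randn_like(Z)  # stand-in for the upstream gradient dL/dZ
Z.backward(dZ)            # autograd reference, fills Q.grad, K.grad, V.grad

with torch.no_grad():
    dV = P.transpose(-1, -2) @ dZ                        # Z = P @ V
    dP = dZ @ V.transpose(-1, -2)
    dS = P * (dP - (dP * P).sum(dim=-1, keepdim=True))   # softmax backward
    dQ = dS @ K * scale                                  # S = Q @ K^T * scale
    dK = dS.transpose(-1, -2) @ Q * scale

print(torch.allclose(dQ, Q.grad, atol=1e-4))
print(torch.allclose(dK, K.grad, atol=1e-4))
print(torch.allclose(dV, V.grad, atol=1e-4))
```

Because matrix multiplication broadcasts over the leading batch and head axes, the formulas are the single-head ones with `.t()` replaced by `.transpose(-1, -2)`.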