In [1]:
import torch
from torch import nn
from d2l import torch as d2l
import math
import pandas as pd

In [2]:
def transpose_qkv(X, num_heads):
    """Rearrange ``X`` so that all attention heads can be computed in parallel.

    The last (hidden) axis of ``X`` is cut into ``num_heads`` equal slices and
    each slice is handed to its own head by folding the head axis into the
    batch axis.

    Args:
        X: Tensor of shape (batch_size, number of queries or key-value
            pairs, num_hiddens).
        num_heads: Number of attention heads. The reshape below only succeeds
            when num_hiddens is divisible by this value.

    Returns:
        Tensor of shape (batch_size * num_heads, number of queries or
        key-value pairs, num_hiddens / num_heads).
    """
    batch_size, num_items = X.shape[0], X.shape[1]

    # Split the hidden axis into one slice per head:
    # (batch_size, num_items, num_heads, num_hiddens / num_heads)
    per_head = X.reshape(batch_size, num_items, num_heads, -1)

    # Move the head axis next to the batch axis:
    # (batch_size, num_heads, num_items, num_hiddens / num_heads)
    per_head = per_head.permute(0, 2, 1, 3)

    # Merge the batch and head axes:
    # (batch_size * num_heads, num_items, num_hiddens / num_heads)
    return per_head.reshape(-1, num_items, per_head.shape[-1])

In [3]:
# Random test input of shape (2, 3, 9); the last axis (9) is divisible by the
# 3 heads used in the round-trip cell below.
a = torch.randn(2, 3, 9)

In [4]:
# Display the (2, 3, 9) sample tensor so it can be compared with the
# round-trip result printed by the next cell.
a

tensor([[[ 0.6553, -0.9672, -0.3537,  0.6736,  0.1459,  0.3441, -0.4819,
          -2.2395,  1.5547],
         [-0.5670, -0.3486, -0.7059, -0.1144,  1.5583, -0.6897,  0.7415,
          -1.0818, -0.5347],
         [ 1.1206,  0.4609, -1.1403, -0.8820, -1.4045,  0.1678,  0.3981,
          -0.0918,  0.5389]],

        [[-0.8795, -1.1234, -1.2054,  1.9456, -0.2745,  0.0574, -0.6853,
          -0.6875,  0.3098],
         [-0.6960, -2.0392,  0.3561,  0.4584,  1.0259, -1.3871,  1.2355,
           0.2464,  0.9230],
         [ 1.5761, -0.0202, -0.3672, -0.7423,  0.2101, -2.4140, -0.1120,
          -0.1747,  0.1859]]])

In [5]:
# Round-trip check: d2l.transpose_output should undo transpose_qkv when both
# are called with the same number of heads (3 here). Assert it explicitly
# instead of relying on a visual comparison of the two printed tensors.
restored = d2l.transpose_output(transpose_qkv(a, 3), 3)
assert torch.equal(restored, a), "transpose_output did not invert transpose_qkv"
restored

tensor([[[ 0.6553, -0.9672, -0.3537,  0.6736,  0.1459,  0.3441, -0.4819,
          -2.2395,  1.5547],
         [-0.5670, -0.3486, -0.7059, -0.1144,  1.5583, -0.6897,  0.7415,
          -1.0818, -0.5347],
         [ 1.1206,  0.4609, -1.1403, -0.8820, -1.4045,  0.1678,  0.3981,
          -0.0918,  0.5389]],

        [[-0.8795, -1.1234, -1.2054,  1.9456, -0.2745,  0.0574, -0.6853,
          -0.6875,  0.3098],
         [-0.6960, -2.0392,  0.3561,  0.4584,  1.0259, -1.3871,  1.2355,
           0.2464,  0.9230],
         [ 1.5761, -0.0202, -0.3672, -0.7423,  0.2101, -2.4140, -0.1120,
          -0.1747,  0.1859]]])

In [6]:
# New random tensor of shape (1, 5, 3) for the slicing example below.
a = torch.randn(1, 5, 3)

In [7]:
# Display the (1, 5, 3) sample tensor before slicing it.
a

tensor([[[-0.0127, -0.2887,  1.5477],
         [-1.5060,  0.9637,  0.9685],
         [-0.9791, -1.1097,  0.1590],
         [-0.0262,  0.3039,  0.9592],
         [-0.5225,  0.4729, -0.2219]]])

In [8]:
# Keep only the first two entries along axis 1; axes 0 and 2 are taken whole.
a[:, :2]

tensor([[[-0.0127, -0.2887,  1.5477],
         [-1.5060,  0.9637,  0.9685]]])

In [12]:
# Random tensor of shape (2, 10, 2), displayed for reference by the
# advanced-indexing example in the following cell.
a = torch.randn(2, 10, 2)
a

tensor([[[-0.5073, -0.4896],
         [ 0.1574,  0.0384],
         [ 0.6590,  0.0907],
         [-0.2328,  3.1810],
         [-0.3632,  0.3216],
         [-1.3317, -0.1453],
         [-3.0024,  0.8046],
         [-0.9502, -0.8820],
         [-0.3713, -0.2523],
         [-0.5856,  1.9133]],

        [[ 1.2407,  0.7914],
         [ 1.0452, -0.9501],
         [-0.5046, -0.5260],
         [-1.1320, -0.8667],
         [ 0.2222, -1.1028],
         [-0.9253,  0.0304],
         [-0.2997,  0.4115],
         [-0.3173,  1.0592],
         [ 0.5815,  0.1927],
         [-1.3292,  2.0807]]])

In [14]:
# Paired (advanced) indexing: the i-th entry of the first index picks the
# batch and the i-th entry of the second picks the row along axis 1, giving
# 6 rows of width 2 (rows 1-3 of batch 0, rows 7-9 of batch 1). The reshape
# then groups them back as 3 rows per batch.
a[
    [0, 0, 0, 1, 1, 1],
    torch.tensor([1, 2, 3, 7, 8, 9]),
].reshape(2, 3, -1)

tensor([[[ 0.1574,  0.0384],
         [ 0.6590,  0.0907],
         [-0.2328,  3.1810]],

        [[-0.3173,  1.0592],
         [ 0.5815,  0.1927],
         [-1.3292,  2.0807]]])