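#### repeat and repeat_interleave
1. What repeat does and how it behaves
- Function: tiles the whole tensor along the specified dimensions.
- Behavior: for each dimension, the full contents of the tensor are copied the number of times given by the corresponding argument.
- Use case: expanding a tensor's dimensions, or repeating its contents as a whole.
- tensor.repeat(*sizes): sizes is a sequence of integers giving the number of repetitions along each dimension.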

In [28]:
import torch
import torch.nn as nn
x = torch.tensor([[1, 2], [3, 4]])
# Original tensor shape: (2, 2)
x.size()

torch.Size([2, 2])

In [7]:
# Repeat 2 times along the first dimension and 3 times along the second
y = x.repeat(2, 3)
# Resulting tensor shape: (4, 6)
print(y.size())
print(y)
# Output:
# tensor([[1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4],
#         [1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4]])

torch.Size([4, 6])
tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4]])
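
To make the "whole tensor is tiled" behavior concrete, the following minimal sketch rebuilds the same result by hand with `torch.cat`. The names `row_block` and `manual` are illustrative only and do not appear in the original notes.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])

# Tile the whole tensor: 2 copies along dim 0, 3 copies along dim 1.
y = x.repeat(2, 3)

# The same result built by hand with torch.cat, to make the tiling explicit.
row_block = torch.cat([x, x, x], dim=1)            # shape (2, 6)
manual = torch.cat([row_block, row_block], dim=0)  # shape (4, 6)

print(y.shape)                 # torch.Size([4, 6])
print(torch.equal(y, manual))  # True
```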


In [13]:
x = torch.tensor([1, 2, 3])  # x is redefined as a 1-D tensor for the cells below
x

tensor([1, 2, 3])

In [14]:
x.size()

torch.Size([3])

In [15]:
x.repeat(1, 1).size()

torch.Size([1, 3])

In [17]:
x.repeat(1, 1)

tensor([[1, 2, 3]])

In [19]:
x.repeat(3, 4, 2)

tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]]])
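
The result above has shape (3, 4, 6) because a 1-D input is treated as if it had leading size-1 dimensions before the repeat counts are applied. A small sketch of that equivalence, where `a` and `b` are illustrative names:

```python
import torch

x = torch.tensor([1, 2, 3])          # shape (3,)

a = x.repeat(3, 4, 2)                # 1-D input, three repeat counts
b = x.view(1, 1, 3).repeat(3, 4, 2)  # same input with explicit leading size-1 dims

print(a.shape)            # torch.Size([3, 4, 6])
print(torch.equal(a, b))  # True
```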

In [21]:
import torch

x = torch.randn(2, 4)


# 1. Copy along a given dimension
x.repeat(1, 1).size()  # torch.Size([2, 4])

x.repeat(2, 1).size()  # torch.Size([4, 4])

x.repeat(1, 2).size()  # torch.Size([2, 8])


# 2. repeat can add new leading dimensions as well as copy along existing ones
x.repeat(1, 1, 1).size()  # torch.Size([1, 2, 4])

x.repeat(2, 1, 1).size()  # torch.Size([2, 2, 4])

x.repeat(1, 1, 1, 1).size()  # torch.Size([1, 1, 2, 4])


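# 3. The number of arguments passed to repeat must not be smaller than the number of dimensions of x
try:
    x.repeat(1)  # raises RuntimeError: x is 2-D but only one repeat count is given
except RuntimeError:
    print(" Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor")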

 Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
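
The three rules in the cell above can be summarized as a shape formula: left-pad the input shape with 1s to the length of `sizes`, then multiply element-wise. The sketch below checks this against `repeat`. `expected_repeat_shape` is a hypothetical helper written for illustration and is not part of PyTorch.

```python
import torch

def expected_repeat_shape(shape, sizes):
    """Hypothetical helper: predict the shape of tensor.repeat(*sizes)."""
    if len(sizes) < len(shape):
        raise ValueError("need at least as many repeat counts as tensor dimensions")
    # Left-pad the input shape with 1s so both sequences have equal length.
    padded = (1,) * (len(sizes) - len(shape)) + tuple(shape)
    return tuple(s * r for s, r in zip(padded, sizes))

x = torch.randn(2, 4)
for sizes in [(1, 1), (2, 1), (1, 2), (2, 1, 1), (1, 1, 1, 1)]:
    # The prediction matches what repeat actually returns.
    assert tuple(x.repeat(*sizes).shape) == expected_repeat_shape(x.shape, sizes)
    print(sizes, "->", expected_repeat_shape(x.shape, sizes))
```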


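2. What repeat_interleave does and how it behaves
- Function: repeats the elements of a tensor one by one along the specified dimension.
- Behavior: operates element by element, repeating each element the given number of times.
- Use case: repetition at the level of individual elements.

- tensor.repeat_interleave(repeats, dim=None)
>repeats: the number of repetitions for each element; either an integer or a tensor whose length matches the size of the input along dim (or the size of the flattened input when dim is None).
>dim: the dimension to operate along (if None, the tensor is flattened to 1-D before the operation).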
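
As a quick illustration of the `dim` parameter described above, this sketch contrasts the default `dim=None` (flatten first) with `dim=1` on a 2-D tensor. The variable names `flat` and `cols` are illustrative.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])

# dim=None (the default): the input is flattened, then each element is repeated.
flat = x.repeat_interleave(2)
print(flat)  # tensor([1, 1, 2, 2, 3, 3, 4, 4])

# dim=1: each column is repeated in place; the number of rows is unchanged.
cols = x.repeat_interleave(2, dim=1)
print(cols)
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])
```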

In [10]:
x = torch.tensor([1, 2, 3])
y = x.repeat_interleave(2)
# Output: tensor([1, 1, 2, 2, 3, 3])
y

tensor([1, 1, 2, 2, 3, 3])

In [11]:
x = torch.tensor([[1, 2], [3, 4]])
y = x.repeat_interleave(2, dim=0)
# Output: tensor([[1, 2],
#                 [1, 2],
#                 [3, 4],
#                 [3, 4]])


In [12]:
x = torch.tensor([1, 2, 3])
y = x.repeat_interleave(torch.tensor([1, 2, 3]))
# Output: tensor([1, 2, 2, 3, 3, 3])
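
Per-element repeat counts also work along a chosen dimension of a multi-dimensional tensor. A minimal sketch, assuming one count per row (`counts` is an illustrative name):

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])

# One repeat count per row: row 0 appears once, row 1 appears three times.
counts = torch.tensor([1, 3])
y = x.repeat_interleave(counts, dim=0)
print(y)
# tensor([[1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])

# The output length along dim equals the sum of the counts.
assert y.shape[0] == int(counts.sum())
```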


In [22]:
## tensor.repeat copies the tensor as a whole, whereas tensor.repeat_interleave copies each element of the tensor individually.
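
The difference is easiest to see by applying both methods to the same input. In the sketch below, `m`, `tiled`, and `interleaved` are illustrative names; the shapes agree but the element order does not.

```python
import torch

x = torch.tensor([1, 2, 3])

print(x.repeat(2))             # tensor([1, 2, 3, 1, 2, 3])  whole tensor tiled
print(x.repeat_interleave(2))  # tensor([1, 1, 2, 2, 3, 3])  each element repeated

# 2-D case along dim 0: the shapes match, but the row order differs.
m = torch.tensor([[1, 2], [3, 4]])
tiled = m.repeat(2, 1)                       # rows: [1,2], [3,4], [1,2], [3,4]
interleaved = m.repeat_interleave(2, dim=0)  # rows: [1,2], [1,2], [3,4], [3,4]
print(tiled.shape == interleaved.shape)      # True
print(torch.equal(tiled, interleaved))       # False
```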

In [50]:
import torch
# (batch_size, seq_len, hidden_dim)
x = torch.randn(3, 5, 16)
print(x)

tensor([[[-1.7303e-01, -1.8239e+00,  1.9584e+00, -2.2530e-03, -8.8339e-01,
           9.2922e-01, -4.7529e-01, -1.1302e+00,  2.1548e+00, -8.1080e-01,
          -6.5028e-01, -6.4566e-01,  3.0621e-01, -1.5821e+00, -8.5193e-01,
           6.4006e-01],
         [ 8.3008e-02, -4.3098e-01, -2.5415e-01, -8.7879e-01,  2.6055e-01,
          -1.8584e+00, -3.7619e-01,  2.9879e-02, -2.8062e-01,  9.4475e-01,
          -4.5170e-01,  2.9434e-01, -2.7260e+00, -2.0521e+00,  8.5321e-01,
           1.3162e-01],
         [ 1.8363e-01,  3.2655e-01,  1.5937e+00, -3.9529e-01,  2.8868e-01,
           9.9687e-01,  1.5748e+00, -2.6170e+00,  1.1463e+00,  6.1534e-01,
           7.3145e-01,  1.3536e+00,  1.9973e-01, -5.2553e-01,  1.1811e+00,
           1.0703e+00],
         [-9.1731e-01, -1.2991e-01, -1.0641e+00,  1.6550e+00, -1.4841e-01,
           8.4696e-01, -8.6567e-01,  1.4960e+00,  9.6555e-01, -1.3881e+00,
          -1.3883e+00, -1.3867e+00,  1.9253e+00, -2.5985e-01,  5.9407e-01,
           8.0398e-01],
    

In [51]:
batch_size, seq_len, hidden_dim = x.size()
nums_head = 8
nums_kv_head = 2
head_dim = hidden_dim // nums_head
print(f"head_dim is {head_dim}")
q_head_per_group = nums_head // nums_kv_head
print(f"q_head_per_group is {q_head_per_group}")

head_dim is 2
q_head_per_group is 4


In [52]:
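# Note: the output size is nums_kv_head * head_dim, i.e. the reduced (K/V-style) head count,
# even though the variable is named q_proj.
q_proj = nn.Linear(hidden_dim, nums_kv_head * head_dim)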

In [53]:
q = q_proj(x)
q

tensor([[[-0.4988,  0.0158,  0.4679, -0.0306],
         [ 0.6421, -0.6273,  0.9947, -0.0307],
         [ 0.7347,  0.4412, -1.0284,  0.1776],
         [-1.2146, -0.1403, -0.3550,  0.8192],
         [ 0.2173,  0.0952, -0.6068,  0.5387]],

        [[-0.8430, -0.9999,  0.8334, -0.1843],
         [ 0.4935,  1.0690, -0.4226,  0.3500],
         [-0.3918,  0.8722, -1.6565,  0.7223],
         [ 0.0234,  1.0082, -0.4153, -0.6250],
         [ 0.1691, -0.5942,  0.0031,  0.5466]],

        [[ 0.5519, -0.1868,  1.0101, -0.4999],
         [-0.5555, -0.8184, -0.0819,  0.3828],
         [-0.3636,  0.0228, -1.0644,  0.4446],
         [-0.6011,  0.6844, -0.3876,  0.9775],
         [-1.2093, -0.8693, -0.7047,  0.5780]]], grad_fn=<ViewBackward0>)

In [54]:
q.size()

torch.Size([3, 5, 4])

In [55]:
q1 = q.view(batch_size, seq_len, nums_kv_head, head_dim).transpose(1, 2)
# (batch_size, nums_kv_head, seq_len, head_dim)
q1.size()

torch.Size([3, 2, 5, 2])

In [57]:
q1_repeat = q1.repeat_interleave(q_head_per_group, dim=1)
# (batch_size, nums_kv_head, seq_len, head_dim)
# -> (batch_size, nums_head, seq_len, head_dim)
q1_repeat.size()

torch.Size([3, 8, 5, 2])
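
Why `repeat_interleave` rather than `repeat` here: both produce shape (3, 8, 5, 2), but they assign the copies to output heads differently. The sketch below uses a random stand-in tensor `kv` in place of `q1` (an assumption for self-containment) and checks which source head each output head comes from.

```python
import torch

batch_size, nums_kv_head, seq_len, head_dim = 3, 2, 5, 2
q_head_per_group = 4
nums_head = nums_kv_head * q_head_per_group

# Stand-in for q1: one (seq_len, head_dim) slice per KV head.
kv = torch.randn(batch_size, nums_kv_head, seq_len, head_dim)

interleaved = kv.repeat_interleave(q_head_per_group, dim=1)  # (3, 8, 5, 2)
tiled = kv.repeat(1, q_head_per_group, 1, 1)                 # (3, 8, 5, 2)

for h in range(nums_head):
    # repeat_interleave: consecutive heads share a source head (grouped layout).
    assert torch.equal(interleaved[:, h], kv[:, h // q_head_per_group])
    # repeat: heads cycle through the source heads instead.
    assert torch.equal(tiled[:, h], kv[:, h % nums_kv_head])

print(interleaved.shape, tiled.shape)
```

With `repeat_interleave`, output heads 0 to 3 map to source head 0 and heads 4 to 7 map to source head 1, which matches the printed `q1_repeat` in the notes.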

In [58]:
q1_repeat

tensor([[[[-0.4988,  0.0158],
          [ 0.6421, -0.6273],
          [ 0.7347,  0.4412],
          [-1.2146, -0.1403],
          [ 0.2173,  0.0952]],

         [[-0.4988,  0.0158],
          [ 0.6421, -0.6273],
          [ 0.7347,  0.4412],
          [-1.2146, -0.1403],
          [ 0.2173,  0.0952]],

         [[-0.4988,  0.0158],
          [ 0.6421, -0.6273],
          [ 0.7347,  0.4412],
          [-1.2146, -0.1403],
          [ 0.2173,  0.0952]],

         [[-0.4988,  0.0158],
          [ 0.6421, -0.6273],
          [ 0.7347,  0.4412],
          [-1.2146, -0.1403],
          [ 0.2173,  0.0952]],

         [[ 0.4679, -0.0306],
          [ 0.9947, -0.0307],
          [-1.0284,  0.1776],
          [-0.3550,  0.8192],
          [-0.6068,  0.5387]],

         [[ 0.4679, -0.0306],
          [ 0.9947, -0.0307],
          [-1.0284,  0.1776],
          [-0.3550,  0.8192],
          [-0.6068,  0.5387]],

         [[ 0.4679, -0.0306],
          [ 0.9947, -0.0307],
          [-1.0284,  0.1776]

In [24]:
# x here is a 2x2 random tensor from an earlier run, e.g. x = torch.randn(2, 2)
print(x.repeat_interleave(2, dim=0))

tensor([[-0.1472, -0.5490],
        [-0.1472, -0.5490],
        [ 0.0124,  1.0488],
        [ 0.0124,  1.0488]])
