In [6]:
import torch

# Source tensor: shape (2, 6), 12 elements in total
x = torch.arange(12).reshape(2, 6)
print("原始形状:", x.shape)
print(x)
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10, 11]])

# view reinterprets the same 12 elements as (2, 3, 2); the element count must match
y = x.view(2, 3, 2)
print("\nview后形状:", y.shape)
print(y)
# tensor([[[ 0,  1],
#          [ 2,  3],
#          [ 4,  5]],
#         [[ 6,  7],
#          [ 8,  9],
#          [10, 11]]])

# Pass -1 for one dimension to have it inferred from the total element count
z = x.view(1, -1)  # -1 is inferred as 12 (12 elements / 1 row)
print("\n自动推断:", z.shape)  # torch.Size([1, 12])
print(z)

原始形状: torch.Size([2, 6])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])

view后形状: torch.Size([2, 3, 2])
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])

自动推断: torch.Size([1, 12])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])


In [7]:
# Source tensor: the 12 values from arange viewed as shape (2, 6)
x = torch.arange(12).view(2, 6)
print("原始形状:", x.shape)
print(x)

原始形状: torch.Size([2, 6])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])


In [11]:
# Reinterpret x as 6 rows of 2 columns; relies on x still being the
# 12-element (2, 6) tensor left in the kernel by an earlier cell
y = x.view(6,2)
print(y)

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]])


In [15]:
# ---- Multi-head attention: split the hidden dimension into heads ----
# Small demo sizes so the printed tensors stay readable:
# batch_size=2, seq_len=3, hidden_dim=4, num_heads=2
batch_size, seq_len, hidden_dim, num_heads = 2, 3, 4, 2
x = torch.randn(batch_size, seq_len, hidden_dim)
print(x)

# Per-head width. hidden_dim has to divide evenly by num_heads; otherwise the
# floor division hides the problem and the view below fails with a less
# helpful shape-mismatch error.
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
head_dim = hidden_dim // num_heads
x_heads = x.view(batch_size, seq_len, num_heads, head_dim)
print(x_heads.shape)
print(x_heads)

# Then permute into the usual layout: (batch, num_heads, seq, head_dim).
# The permuted tensor is generally non-contiguous, so a later .view() on it
# may need .contiguous() first (or use .reshape()).
x_heads = x_heads.permute(0, 2, 1, 3)
print("多头 shape:", x_heads.shape)  # torch.Size([2, 2, 3, 2])

# ---- CNN: flatten a feature map before a fully connected layer ----
# Feature map layout: (batch_size, channels, height, width)
batch_size, channels, height, width = 8, 32, 7, 7
feature_map = torch.randn(batch_size, channels, height, width)

# Keep the batch axis and collapse everything else: 32 * 7 * 7 = 1568
flattened = feature_map.view(batch_size, -1)  # (8, 1568)
print("展平特征 shape:", flattened.shape)


tensor([[[-0.0519,  0.3264,  0.6561,  1.0575],
         [ 0.8724, -1.1942, -0.5756, -0.2264],
         [-0.6480, -0.2886,  0.6631,  0.8495]],

        [[ 0.4750,  0.8697,  0.1260, -0.5719],
         [ 1.3796,  1.7830, -1.0428,  0.0322],
         [ 1.7248, -0.1007, -0.0318, -1.6944]]])
torch.Size([2, 3, 2, 2])
tensor([[[[-0.0519,  0.3264],
          [ 0.6561,  1.0575]],

         [[ 0.8724, -1.1942],
          [-0.5756, -0.2264]],

         [[-0.6480, -0.2886],
          [ 0.6631,  0.8495]]],


        [[[ 0.4750,  0.8697],
          [ 0.1260, -0.5719]],

         [[ 1.3796,  1.7830],
          [-1.0428,  0.0322]],

         [[ 1.7248, -0.1007],
          [-0.0318, -1.6944]]]])
多头 shape: torch.Size([2, 2, 3, 2])
展平特征 shape: torch.Size([8, 1568])
