In [20]:
import torch

def create_causal_mask(seq_length: int, device: "torch.device | str | None" = None) -> torch.Tensor:
    """
    Create a causal (look-ahead) attention mask.

    Args:
        seq_length: sequence length.
        device: optional device on which to allocate the mask. When None
            (the default), PyTorch's default device is used, as before. Pass
            the device of the attention scores to avoid a cross-device
            ``masked_fill``.

    Returns:
        mask: boolean tensor of shape (seq_length, seq_length).
            ``mask[i, j]`` is False where position i may attend to position j
            (j <= i), and True where attention is blocked (j > i, the strict
            upper triangle). Intended for ``scores.masked_fill(mask, -inf)``.
    """
    # Build the strict upper triangle (diagonal excluded) directly as bool,
    # skipping the float32 intermediate of ones(...).triu(...).bool().
    return torch.ones(seq_length, seq_length, dtype=torch.bool, device=device).triu(diagonal=1)

# Test code: build a small causal mask and display it.
seq_len = 4
causal_mask = create_causal_mask(seq_len)
print("Causal Mask:")
print(causal_mask)

# Inline check: mask[i, j] must be True exactly when j > i (a future position).
positions = torch.arange(seq_len)
expected_mask = positions.unsqueeze(0) > positions.unsqueeze(1)
assert causal_mask.dtype == torch.bool
assert torch.equal(causal_mask, expected_mask)

Causal Mask:
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [23]:
import torch
import torch.nn.functional as F

def apply_causal_mask(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                      verbose: bool = True) -> torch.Tensor:
    """
    Scaled dot-product self-attention with a causal (look-ahead) mask.

    Position i may only attend to positions j <= i. The scores
    ``query @ key^T / sqrt(d_k)`` are set to ``-inf`` wherever ``j > i``
    before the softmax, so those positions receive exactly zero weight.

    Args:
        query: query tensor of shape (..., seq_len, d_k), e.g.
            (batch_size, seq_len, d_model).
        key: key tensor of shape (..., seq_len, d_k). Its sequence length must
            equal the query's, because the causal mask is square.
        value: value tensor of shape (..., seq_len, d_v).
        verbose: if True (the default, matching the original demo), print the
            raw scores, the masked scores and the attention weights.

    Returns:
        attention_output: tensor of shape (..., seq_len, d_v).
    """
    # Raw attention scores: (..., seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # Scale by sqrt(d_k). A Python float avoids allocating a float32 CPU tensor
    # and keeps full precision for float64 inputs.
    d_k = query.size(-1)
    scores = scores / (d_k ** 0.5)
    if verbose:
        print("scores:", scores)

    # The sequence axis is -2, so multi-head inputs (batch, heads, seq, d_k)
    # also work. The mask is moved to the scores' device so GPU inputs do not
    # fail in masked_fill.
    seq_len = query.size(-2)
    mask = create_causal_mask(seq_len).to(scores.device)

    # Future positions get -inf, which softmax maps to exactly 0. The diagonal
    # is never masked, so no row is entirely -inf and no NaNs can appear.
    scores = scores.masked_fill(mask, float('-inf'))
    if verbose:
        print("scores:", scores)

    # Normalise over the key axis to get attention weights.
    attention_weights = F.softmax(scores, dim=-1)
    if verbose:
        print("attention_weights:", attention_weights)

    # Weighted sum of the values.
    attention_output = torch.matmul(attention_weights, value)

    return attention_output

# Test code: run causal self-attention on random inputs.
batch_size = 2
seq_len = 4
d_model = 8

# Seed the RNG so the printed scores and weights are reproducible on
# Restart & Run All.
torch.manual_seed(0)

# Random example inputs of shape (batch_size, seq_len, d_model)
query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)

# Attention with the causal mask applied
output = apply_causal_mask(query, key, value)
print("输入形状:", query.shape)
print("输出形状:", output.shape)

# Sanity checks. The shape is preserved. Position 0 can only attend to itself,
# which gives it weight 1, so its output must equal value[:, 0].
assert output.shape == (batch_size, seq_len, d_model)
assert torch.allclose(output[:, 0], value[:, 0])

scores: tensor([[[ 1.5958, -1.4635, -1.4237,  1.2843],
         [ 1.7220,  0.9283, -0.4925,  0.3239],
         [ 0.8589,  0.9786, -0.0241, -0.1395],
         [ 0.6889, -0.1886,  0.2087, -1.0112]],

        [[-2.2927,  0.8763, -1.5735, -0.2772],
         [-0.9213, -0.9578,  1.8836,  2.6790],
         [ 1.2034,  0.4874, -0.8374, -1.0190],
         [-0.1216,  0.7642, -1.0607, -1.5804]]])
scores: tensor([[[ 1.5958,    -inf,    -inf,    -inf],
         [ 1.7220,  0.9283,    -inf,    -inf],
         [ 0.8589,  0.9786, -0.0241,    -inf],
         [ 0.6889, -0.1886,  0.2087, -1.0112]],

        [[-2.2927,    -inf,    -inf,    -inf],
         [-0.9213, -0.9578,    -inf,    -inf],
         [ 1.2034,  0.4874, -0.8374,    -inf],
         [-0.1216,  0.7642, -1.0607, -1.5804]]])
attention_weights: tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.6886, 0.3114, 0.0000, 0.0000],
         [0.3936, 0.4436, 0.1628, 0.0000],
         [0.4510, 0.1876, 0.2790, 0.0824]],

        [[1.0000, 0.0000, 0.000

In [26]:
import numpy as np
import torch
n = 3
# Upper triangle *including* the diagonal (k=0). A strict causal mask would
# use k=1 and leave the diagonal unmasked.
mask = np.triu(np.full((n, n), True), k=0)

print(mask)

[[ True  True  True]
 [False  True  True]
 [False False  True]]
