# Chunked Prefill & Flash Decoding计算演示

Author: kaiyuan

Email: kyxie@zju.edu.cn

## 1 Chunked Prefill计算

构造一个 Chunked Prefill（分块预填充）实现，并将其计算结果与标准（不分块）attention 的结果进行对比。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List


class CausalChunkedPrefill(nn.Module):
    """
    Streaming, causal chunked-prefill attention for autoregressive inference.

    The module owns one multi-head attention layer (Q/K/V/output projections)
    and exposes three ways of running it:

    * ``prefill_standard`` - causal attention over the whole sequence at once;
      used as the correctness reference.
    * ``prefill_chunked`` - walks the sequence ``chunk_size`` tokens at a time
      and accumulates a per-chunk KV cache.
    * ``decode_step`` - attends newly arrived token(s) against a KV cache.
    """

    def __init__(self, d_model: int, n_heads: int, chunk_size: int = 512):
        """
        Args:
            d_model: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            chunk_size: Number of tokens processed per prefill chunk (> 0).

        Raises:
            AssertionError: If ``d_model`` is not divisible by ``n_heads``.
            ValueError: If ``chunk_size`` is not positive.
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        if chunk_size <= 0:
            # Fail here instead of with a ZeroDivisionError inside prefill_chunked.
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        self.d_model = d_model
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.head_dim = d_model // n_heads

        # Q/K/V and output projections.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, seq, d_model] into [batch, n_heads, seq, head_dim]."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, n_heads, seq, head_dim] back into [batch, seq, n_heads * head_dim]."""
        batch_size, n_heads, seq_len, head_dim = x.shape
        # Spell out the last dimension: view(..., -1) cannot be inferred for
        # zero-element tensors (seq_len == 0).
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, n_heads * head_dim)

    def _project_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to Q, K and V, each split into heads."""
        return (
            self._split_heads(self.q_proj(x)),
            self._split_heads(self.k_proj(x)),
            self._split_heads(self.v_proj(x)),
        )

    @staticmethod
    def _causal_mask(q_start: int, q_len: int, kv_len: int, device) -> torch.Tensor:
        """
        Boolean [q_len, kv_len] mask that is True where a query may attend.

        Queries sit at absolute positions ``q_start .. q_start + q_len - 1`` and
        keys at ``0 .. kv_len - 1``; a query sees keys at positions <= its own.
        """
        q_positions = torch.arange(q_start, q_start + q_len, device=device).unsqueeze(1)
        kv_positions = torch.arange(kv_len, device=device).unsqueeze(0)
        return q_positions >= kv_positions

    def _attend(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Scaled dot-product attention followed by head merge and output projection.

        Args:
            q: [batch, n_heads, q_len, head_dim]
            k, v: [batch, n_heads, kv_len, head_dim]
            mask: Optional boolean [q_len, kv_len]; False entries are masked out.

        Returns:
            [batch, q_len, d_model]
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        return self.out_proj(self._merge_heads(attn_output))

    def prefill_standard(self, x: torch.Tensor) -> torch.Tensor:
        """
        Reference implementation: causal attention over the whole sequence.

        Args:
            x: Input sequence [batch, seq_len, d_model].

        Returns:
            Attention output [batch, seq_len, d_model].
        """
        _, seq_len, _ = x.shape
        q, k, v = self._project_qkv(x)
        # Lower-triangular mask: position i sees positions 0..i.
        mask = self._causal_mask(0, seq_len, seq_len, x.device)
        return self._attend(q, k, v, mask)

    def prefill_chunked(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        """
        Chunked (streaming, causal) prefill.

        Each chunk's queries attend to the keys/values of the current and all
        earlier chunks, with a causal mask inside the current chunk. Progress
        is printed to stdout.

        Args:
            x: Input sequence [batch, seq_len, d_model].

        Returns:
            output: Attention output [batch, seq_len, d_model].
            kv_cache: ``(k_list, v_list)``; one [batch, n_heads, chunk_len,
                head_dim] tensor per processed chunk (both lists are empty for
                an empty sequence).
        """
        batch_size, seq_len, _ = x.shape

        # Ceil division: the last chunk may be shorter than chunk_size.
        n_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size

        k_cache: List[torch.Tensor] = []
        v_cache: List[torch.Tensor] = []
        outputs: List[torch.Tensor] = []

        print(f"分块预填充: 序列长度={seq_len}, 分块大小={self.chunk_size}, 分块数={n_chunks}")

        if seq_len == 0:
            # torch.cat on an empty list raises; return an empty result instead.
            return x.new_zeros(batch_size, 0, self.d_model), (k_cache, v_cache)

        for chunk_idx in range(n_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, seq_len)
            chunk_len = end - start

            q, k, v = self._project_qkv(x[:, start:end, :])

            k_cache.append(k)
            v_cache.append(v)

            # The cache starts empty and chunks are contiguous from position 0,
            # so the cached length is exactly the end of the current chunk.
            total_kv_len = end

            # Everything up to and including the current chunk is visible.
            k_all = torch.cat(k_cache, dim=2)  # [batch, n_heads, total_kv_len, head_dim]
            v_all = torch.cat(v_cache, dim=2)

            # Queries of this chunk must not see later keys of the same chunk.
            mask = self._causal_mask(start, chunk_len, total_kv_len, x.device)

            outputs.append(self._attend(q, k_all, v_all, mask))

            print(f"  处理chunk {chunk_idx+1}/{n_chunks}: "
                  f"位置 {start}:{end}, "
                  f"KV缓存长度={total_kv_len}")

        return torch.cat(outputs, dim=1), (k_cache, v_cache)

    def decode_step(self,
                   x: torch.Tensor,
                   kv_cache: Tuple[List[torch.Tensor], List[torch.Tensor]]
                   ) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        """
        Decode step: attend new token(s) against the KV cache.

        The two lists inside ``kv_cache`` are extended IN PLACE with the new
        keys/values, and the same list objects are returned.

        Args:
            x: New token embeddings [batch, new_len, d_model]; ``new_len`` is
                normally 1. With more than one token a causal mask is applied
                among the new tokens.
            kv_cache: ``(k_list, v_list)`` as returned by ``prefill_chunked``
                or by a previous ``decode_step``.

        Returns:
            output: [batch, new_len, d_model].
            updated_kv_cache: ``(k_list, v_list)`` including the new entries.
        """
        k_cache, v_cache = kv_cache

        q, k, v = self._project_qkv(x)  # each [batch, n_heads, new_len, head_dim]

        k_cache.append(k)
        v_cache.append(v)

        k_all = torch.cat(k_cache, dim=2)
        v_all = torch.cat(v_cache, dim=2)

        new_len = q.shape[2]
        total_kv_len = k_all.shape[2]
        # A single new token is the last position and may see every key, so no
        # mask is needed; several new tokens must not see each other's future.
        mask = None
        if new_len > 1:
            mask = self._causal_mask(total_kv_len - new_len, new_len, total_kv_len, x.device)

        return self._attend(q, k_all, v_all, mask), (k_cache, v_cache)

构造一个流式计算类：

In [2]:
class StreamingLLMAttention:
    """
    Stateful streaming wrapper around ``CausalChunkedPrefill``.

    Holds the KV cache between calls: ``prefill`` creates it from a prompt,
    ``generate_token`` consumes and updates it, ``reset_cache`` drops it.
    """

    def __init__(self, d_model: int, n_heads: int, chunk_size: int = 512):
        # No cache exists until prefill() has been called.
        self.kv_cache = None
        self.chunk_size = chunk_size
        self.attn = CausalChunkedPrefill(d_model, n_heads, chunk_size)

    def prefill(self, prompt: torch.Tensor) -> torch.Tensor:
        """
        Prefill phase: run the prompt through chunked prefill and keep its cache.

        Args:
            prompt: Prompt embeddings [batch, prompt_len, d_model].

        Returns:
            The attention output returned by ``prefill_chunked``.
        """
        attn_out, cache = self.attn.prefill_chunked(prompt)
        self.kv_cache = cache
        return attn_out

    def generate_token(self, token_emb: torch.Tensor) -> torch.Tensor:
        """
        Run one decode step for a single token.

        Args:
            token_emb: Embedding of the current token [batch, 1, d_model].

        Returns:
            The attention output for the current token.

        Raises:
            ValueError: If ``prefill`` has not populated the cache yet.
        """
        cache = self.kv_cache
        if cache is None:
            raise ValueError("请先调用prefill初始化KV缓存")

        step_out, self.kv_cache = self.attn.decode_step(token_emb, cache)
        return step_out

    def reset_cache(self):
        """Discard the stored KV cache."""
        self.kv_cache = None

设计测试函数（正确性测试、流式API测试、性能基准测试）并运行：

In [3]:
def test_causal_chunked_prefill():
    """
    Check chunked prefill and cached decoding against full causal attention.

    Steps (all printed to stdout):
      1. Run ``prefill_standard`` on a random input.
      2. Run ``prefill_chunked`` on the same input.
      3. Compare both outputs (max / mean absolute difference).
      4. Decode one extra random token with the KV cache and compare it with
         the last position of ``prefill_standard`` on the extended sequence.

    Every forward pass runs under ``torch.no_grad()``; no gradients are needed
    for a numerical comparison.

    Returns:
        bool: True when both the prefill and the decode difference are below
        the tolerance (1e-5).
    """
    torch.manual_seed(42)

    print("=" * 70)
    print("流式 + 因果分块预填充测试")
    print("=" * 70)

    # Test configuration: a short sequence that splits into three full chunks.
    batch_size = 2
    seq_len = 9
    d_model = 64
    n_heads = 4
    chunk_size = 3
    tolerance = 1e-5

    print("配置:")
    print(f"  batch_size={batch_size}, seq_len={seq_len}")
    print(f"  d_model={d_model}, n_heads={n_heads}")
    print(f"  chunk_size={chunk_size}")

    model = CausalChunkedPrefill(
        d_model=d_model,
        n_heads=n_heads,
        chunk_size=chunk_size
    )

    x = torch.randn(batch_size, seq_len, d_model)

    print(f"\n输入形状: {x.shape}")

    print("\n1. 计算标准注意力（不分块）...")
    with torch.no_grad():
        output_standard = model.prefill_standard(x)
    print(f"   标准注意力输出形状: {output_standard.shape}")

    print("\n2. 计算分块预填充...")
    with torch.no_grad():
        output_chunked, kv_cache = model.prefill_chunked(x)
    print(f"   分块预填充输出形状: {output_chunked.shape}")

    print("\n3. 比较两种方法的输出...")
    diff = torch.abs(output_standard - output_chunked)
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()

    print(f"   最大差异: {max_diff:.10f}")
    print(f"   平均差异: {mean_diff:.10f}")

    prefill_ok = max_diff < tolerance
    if prefill_ok:
        print(f"   ✓ 测试通过！差异在容忍范围内 (< {tolerance})")
    else:
        print("   ✗ 测试失败！差异超出容忍范围")
        print("\n   调试信息（第一个样本的前3个位置）:")
        for pos in range(min(3, seq_len)):
            print(f"   位置 {pos}:")
            print(f"     标准: {output_standard[0, pos, :5].detach().numpy().round(4)}")
            print(f"     分块: {output_chunked[0, pos, :5].detach().numpy().round(4)}")
            print(f"     差异: {diff[0, pos, :5].detach().numpy().round(8)}")

    print("\n4. 测试解码步骤（生成后续token）...")

    # One random embedding stands in for the next token.
    test_token = torch.randn(batch_size, 1, d_model)

    with torch.no_grad():
        # Cached path: attend the new token against the prefill cache.
        output_decode, updated_kv_cache = model.decode_step(test_token, kv_cache)

    print(f"   解码输出形状: {output_decode.shape}")
    print(f"   更新后KV缓存长度: {sum(k.shape[2] for k in updated_kv_cache[0])}")

    with torch.no_grad():
        # Reference path: full attention over the extended sequence; its last
        # position corresponds to the new token.
        x_with_new = torch.cat([x, test_token], dim=1)
        output_full_last = model.prefill_standard(x_with_new)[:, -1:, :]

    diff_decode = torch.abs(output_decode - output_full_last).max().item()
    print(f"   解码vs标准差异: {diff_decode:.10f}")

    decode_ok = diff_decode < tolerance
    if decode_ok:
        print("   ✓ 解码步骤正确")
    else:
        print("   ✗ 解码步骤有误")

    return prefill_ok and decode_ok


def test_streaming_api():
    """
    Smoke-test the streaming API: prefill a prompt, generate three tokens,
    then reset the cache. Only shapes are printed; nothing is asserted.
    """
    torch.manual_seed(42)

    banner = "=" * 70
    print("\n" + banner)
    print("流式API测试")
    print(banner)

    d_model = 64
    prompt_len = 9
    num_new_tokens = 3

    # The attention layer is built before any input is drawn, so the random
    # stream is consumed in the order: weights, prompt, tokens.
    stream_attn = StreamingLLMAttention(d_model=d_model, n_heads=4, chunk_size=3)

    # Random embeddings stand in for a real prompt.
    prompt = torch.randn(1, prompt_len, d_model)

    print(f"1. 预填充阶段: 处理{prompt_len}个token的prompt")
    prefill_out = stream_attn.prefill(prompt)
    print(f"   输出形状: {prefill_out.shape}")

    print(f"\n2. 生成阶段: 生成{num_new_tokens}个token")
    for step in range(1, num_new_tokens + 1):
        # In a real model this embedding would come from the embedding layer.
        next_emb = torch.randn(1, 1, d_model)
        step_out = stream_attn.generate_token(next_emb)
        print(f"   生成token {step}: 输出形状 {step_out.shape}")

    print("\n3. 重置缓存")
    stream_attn.reset_cache()
    print("   KV缓存已重置")


def _measure_prefill(run, use_cuda):
    """
    Time one call of ``run()`` under ``torch.no_grad()``.

    Args:
        run: Zero-argument callable performing the forward pass.
        use_cuda: When True, the CUDA peak-memory counter is reset before the
            call and the device is synchronized around the timed region so the
            wall-clock time covers the queued GPU work.

    Returns:
        (elapsed_seconds, peak_bytes): ``peak_bytes`` is the CUDA peak
        allocation during the call (it includes tensors that were already
        allocated, such as weights and the input) or 0 without CUDA.
    """
    import time

    if use_cuda:
        torch.cuda.empty_cache()
        # Without the reset, the peak would carry over from earlier runs.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    start = time.perf_counter()
    with torch.no_grad():
        run()
    if use_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    peak_bytes = torch.cuda.max_memory_allocated() if use_cuda else 0
    return elapsed, peak_bytes


def benchmark_performance():
    """
    Benchmark standard vs. chunked prefill for several sequence lengths.

    The model and inputs are placed on the CUDA device when one is available
    (otherwise on CPU, in which case only timings are printed). For each
    sequence length the wall-clock time and, on CUDA, the peak allocated
    memory of both paths are printed, followed by the speed-up and the
    relative memory reduction. A ``RuntimeError`` raised by a run (reported
    as OOM) is printed and does not abort the benchmark.
    """
    torch.manual_seed(42)

    print("\n" + "=" * 70)
    print("性能基准测试")
    print("=" * 70)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Long-sequence configuration.
    d_model = 1024
    n_heads = 16
    chunk_size = 512

    model = CausalChunkedPrefill(
        d_model=d_model,
        n_heads=n_heads,
        chunk_size=chunk_size
    ).to(device)

    test_cases = [
        {"seq_len": 512, "desc": "短序列（一个chunk）"},
        {"seq_len": 2048, "desc": "中等序列（4个chunk）"},
        {"seq_len": 8192, "desc": "长序列（16个chunk）"},
    ]

    for test_case in test_cases:
        seq_len = test_case["seq_len"]

        print(f"\n测试: {test_case['desc']} (seq_len={seq_len})")

        x = torch.randn(1, seq_len, d_model, device=device)

        # Standard attention (may run out of memory on long sequences).
        try:
            time_std, mem_std = _measure_prefill(lambda: model.prefill_standard(x), use_cuda)
            line = f"  标准注意力: {time_std:.3f}s"
            if use_cuda:
                line += f", 内存: {mem_std/1024**2:.1f}MB"
            print(line)
        except RuntimeError as e:
            print(f"  标准注意力 OOM: {e}")
            time_std = float('inf')
            mem_std = float('inf')

        # Chunked attention.
        try:
            time_chunk, mem_chunk = _measure_prefill(lambda: model.prefill_chunked(x), use_cuda)
            line = f"  分块注意力: {time_chunk:.3f}s"
            if use_cuda:
                line += f", 内存: {mem_chunk/1024**2:.1f}MB"
            print(line)

            # Ratios are only meaningful when the standard run succeeded.
            if time_std != float('inf'):
                speedup = time_std / time_chunk if time_chunk > 0 else 0
                mem_reduction = (mem_std - mem_chunk) / mem_std * 100 if mem_std > 0 else 0
                print(f"  加速: {speedup:.1f}x, 内存减少: {mem_reduction:.1f}%")
        except RuntimeError as e:
            print(f"  分块注意力 OOM: {e}")


print("流式 + 因果分块预填充实现")

# Correctness first: the remaining demos only make sense if it passes.
test_passed = test_causal_chunked_prefill()

if test_passed:
    # Exercise the stateful streaming wrapper.
    test_streaming_api()

    # The benchmark is only run when a CUDA device is present.
    if not torch.cuda.is_available():
        print("\n注意: CUDA不可用，跳过性能测试")
    else:
        benchmark_performance()

流式 + 因果分块预填充实现
流式 + 因果分块预填充测试
配置:
  batch_size=2, seq_len=9
  d_model=64, n_heads=4
  chunk_size=3

输入形状: torch.Size([2, 9, 64])

1. 计算标准注意力（不分块）...
   标准注意力输出形状: torch.Size([2, 9, 64])

2. 计算分块预填充...
分块预填充: 序列长度=9, 分块大小=3, 分块数=3
  处理chunk 1/3: 位置 0:3, KV缓存长度=3
  处理chunk 2/3: 位置 3:6, KV缓存长度=6
  处理chunk 3/3: 位置 6:9, KV缓存长度=9
   分块预填充输出形状: torch.Size([2, 9, 64])

3. 比较两种方法的输出...
   最大差异: 0.0000001192
   平均差异: 0.0000000094
   ✓ 测试通过！差异在容忍范围内 (< 1e-05)

4. 测试解码步骤（生成后续token）...
   解码输出形状: torch.Size([2, 1, 64])
   更新后KV缓存长度: 10
   解码vs标准差异: 0.0000000745
   ✓ 解码步骤正确

流式API测试
1. 预填充阶段: 处理9个token的prompt
分块预填充: 序列长度=9, 分块大小=3, 分块数=3
  处理chunk 1/3: 位置 0:3, KV缓存长度=3
  处理chunk 2/3: 位置 3:6, KV缓存长度=6
  处理chunk 3/3: 位置 6:9, KV缓存长度=9
   输出形状: torch.Size([1, 9, 64])

2. 生成阶段: 生成3个token
   生成token 1: 输出形状 torch.Size([1, 1, 64])
   生成token 2: 输出形状 torch.Size([1, 1, 64])
   生成token 3: 输出形状 torch.Size([1, 1, 64])

3. 重置缓存
   KV缓存已重置

注意: CUDA不可用，跳过性能测试


## 2 Flash Decoding计算演示

### 2.1 方式一：保存 max 和 block_sum_exp 值

In [4]:
import torch
import torch.nn.functional as F
import math
import time


class FlashDecodingDemo:
    """
    Demonstration of Flash-Decoding style attention.

    All variants split the key/value sequence into blocks, keep a running
    maximum and a running sum of exponentials per query, and combine the
    partial results so that the output matches ordinary softmax attention.
    Every flash variant also returns the full attention weights, computed the
    ordinary way, purely for verification.
    """

    def __init__(self, d_model: int = 64, num_heads: int = 8):
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

    def _scaled_scores(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return ``q @ k^T / sqrt(self.head_dim)``."""
        return torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

    def _full_attention_weights(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Softmax attention weights over the whole key sequence (verification only)."""
        return F.softmax(self._scaled_scores(q, k), dim=-1)

    @staticmethod
    def _rescale_factor(old_max: torch.Tensor, new_max: torch.Tensor) -> torch.Tensor:
        """
        Factor ``exp(old_max - new_max)`` that moves an accumulator to a new max.

        Where ``old_max`` is -inf the accumulator is still empty and the factor
        is defined as 0. This avoids ``-inf - (-inf) = nan`` when two empty
        accumulators are combined.
        """
        factor = torch.exp(old_max - new_max)
        return torch.where(old_max == float("-inf"), torch.zeros_like(factor), factor)

    def traditional_attention(self, q, k, v):
        """
        Ordinary (unblocked) softmax attention.

        Returns:
            (output, attention_weights)
        """
        attention_weights = self._full_attention_weights(q, k)
        output = torch.matmul(attention_weights, v)
        return output, attention_weights

    def flash_decoding_attention(self, q, k, v, block_size=32, tiling_mode='distributed'):
        """
        Dispatch to one of the blocked implementations.

        Args:
            block_size: Number of key/value positions per block/tile.
            tiling_mode: ``'distributed'`` selects the multi-stream variant;
                any other value selects the single-accumulator variant.

        Returns:
            (output, full_attention_weights)
        """
        if tiling_mode == 'distributed':
            return self.flash_decoding_distributed_tiling(q, k, v,
                                                          tile_size_kv=block_size)
        return self.flash_decoding_without_cp(q, k, v, block_size)

    def flash_decoding_without_cp(self, q, k, v, block_size=32):
        """
        Blocked attention with a single running accumulator (online softmax).

        Blocks are visited in order; whenever the running maximum grows, the
        accumulated numerator and denominator are rescaled to the new maximum.

        Returns:
            (output, full_attention_weights)
        """
        batch_size, num_heads, seq_len_q, _ = q.shape
        seq_len_kv = k.shape[2]

        # Running weighted sum, running normaliser and running maximum.
        numerator = q.new_zeros(batch_size, num_heads, seq_len_q, v.shape[-1])
        d_prime = q.new_zeros(batch_size, num_heads, seq_len_q, 1)
        global_max = q.new_full((batch_size, num_heads, seq_len_q, 1), -float('inf'))

        for start_idx in range(0, seq_len_kv, block_size):
            end_idx = min(start_idx + block_size, seq_len_kv)
            k_block = k[:, :, start_idx:end_idx, :]
            v_block = v[:, :, start_idx:end_idx, :]

            scores_block = self._scaled_scores(q, k_block)
            block_max = scores_block.max(dim=-1, keepdim=True).values
            new_global_max = torch.maximum(global_max, block_max)

            # Bring what has been accumulated so far onto the new scale.
            adjustment_factor = self._rescale_factor(global_max, new_global_max)
            global_max = new_global_max

            # Subtracting the running maximum keeps exp() in range.
            exp_scores = torch.exp(scores_block - global_max)

            numerator = numerator * adjustment_factor + torch.matmul(exp_scores, v_block)
            d_prime = d_prime * adjustment_factor + exp_scores.sum(dim=-1, keepdim=True)

        final_output = numerator / d_prime
        return final_output, self._full_attention_weights(q, k)

    def flash_decoding_distributed_tiling(self, q, k, v,
                          tile_size_kv: int = 256,
                          num_streams: int = 5):
        """
        Simulate several parallel compute streams with explicit arrays.

        Tiles are assigned to streams round-robin; every stream keeps its own
        accumulators (O: weighted sum, M: running max, L: sum of exponentials).
        The per-stream results are merged by ``reduce_stream_arrays``. Streams
        that receive no tile stay empty and contribute nothing.

        Returns:
            (output, full_attention_weights)

        Raises:
            ValueError: If ``num_streams`` is smaller than 1.
        """
        if num_streams < 1:
            raise ValueError(f"num_streams must be >= 1, got {num_streams}")

        batch_size, num_heads, seq_len_q, _ = q.shape
        seq_len_kv = k.shape[2]
        num_tiles = (seq_len_kv + tile_size_kv - 1) // tile_size_kv

        print(f"\n分布式数组实现: {num_streams}个计算流")
        print(f"每个流有自己的O、M、L数组")

        # One independent (O, M, L) accumulator triple per stream.
        stream_O = [q.new_zeros(batch_size, num_heads, seq_len_q, v.shape[-1])
                    for _ in range(num_streams)]
        stream_M = [q.new_full((batch_size, num_heads, seq_len_q, 1), -float('inf'))
                    for _ in range(num_streams)]
        stream_L = [q.new_zeros(batch_size, num_heads, seq_len_q, 1)
                    for _ in range(num_streams)]

        print(f"并行处理{num_tiles}个tile...")

        for tile_idx in range(num_tiles):
            # Round-robin assignment of tiles to streams.
            stream_id = tile_idx % num_streams

            start_idx = tile_idx * tile_size_kv
            end_idx = min(start_idx + tile_size_kv, seq_len_kv)
            k_tile = k[:, :, start_idx:end_idx, :]
            v_tile = v[:, :, start_idx:end_idx, :]

            # A stream only touches its own accumulators.
            O_curr = stream_O[stream_id]
            M_curr = stream_M[stream_id]
            L_curr = stream_L[stream_id]

            # Same scaling as the reference (self.head_dim), so results agree.
            S_tile = self._scaled_scores(q, k_tile)
            m_tile = S_tile.max(dim=-1, keepdim=True).values

            new_M = torch.maximum(M_curr, m_tile)
            # Always rescale: the factor is exactly 1 where the max is unchanged.
            scale = self._rescale_factor(M_curr, new_M)

            exp_tile = torch.exp(S_tile - new_M)

            stream_O[stream_id] = O_curr * scale + torch.matmul(exp_tile, v_tile)
            stream_L[stream_id] = L_curr * scale + exp_tile.sum(dim=-1, keepdim=True)
            stream_M[stream_id] = new_M

        final_output = self.reduce_stream_arrays(stream_O, stream_M, stream_L)
        return final_output, self._full_attention_weights(q, k)

    def _merge_pair(self, left, right):
        """Merge two (O, M, L) accumulator triples onto their common maximum."""
        O1, M1, L1 = left
        O2, M2, L2 = right
        new_M = torch.maximum(M1, M2)
        scale1 = self._rescale_factor(M1, new_M)
        scale2 = self._rescale_factor(M2, new_M)
        return O1 * scale1 + O2 * scale2, new_M, L1 * scale1 + L2 * scale2

    def reduce_stream_arrays(self, stream_O, stream_M, stream_L):
        """
        Tree-reduce per-stream accumulators into the final attention output.

        Adjacent streams are merged pairwise, round after round; with an odd
        count the last stream is carried into the next round unchanged.

        Args:
            stream_O: Per-stream un-normalised weighted sums.
            stream_M: Per-stream running maxima (-inf for an empty stream).
            stream_L: Per-stream sums of exponentials.

        Returns:
            ``O / L`` of the fully merged accumulator.

        Raises:
            ValueError: If no stream is given or the three lists differ in length.
        """
        if not stream_O:
            raise ValueError("reduce_stream_arrays needs at least one stream")
        if not (len(stream_O) == len(stream_M) == len(stream_L)):
            raise ValueError("stream_O, stream_M and stream_L must have the same length")

        current = list(zip(stream_O, stream_M, stream_L))

        while len(current) > 1:
            merged = [self._merge_pair(current[i], current[i + 1])
                      for i in range(0, len(current) - 1, 2)]
            if len(current) % 2:
                # Odd count: the last stream moves on unmerged.
                merged.append(current[-1])
            current = merged

        merged_O, _, merged_L = current[0]
        final_output = merged_O / merged_L
        print(f"归约完成，最终输出形状: {final_output.shape}")
        return final_output

    def flash_decoding_attention_simple(self, q, k, v, block_size=32):
        """
        Two-pass Flash-Decoding variant.

        Pass 1 computes, for every block independently, the un-normalised
        output together with the block max and the block sum of exponentials.
        Pass 2 rescales each block to the global max and combines them.

        Returns:
            (output, full_attention_weights)
        """
        seq_len_kv = k.shape[2]

        block_outputs = []
        block_max_vals = []
        block_sum_exps = []

        # Pass 1: independent per-block partial results.
        for start_idx in range(0, seq_len_kv, block_size):
            end_idx = min(start_idx + block_size, seq_len_kv)
            k_block = k[:, :, start_idx:end_idx, :]
            v_block = v[:, :, start_idx:end_idx, :]

            scores_block = self._scaled_scores(q, k_block)
            block_max = scores_block.max(dim=-1, keepdim=True).values
            exp_scores = torch.exp(scores_block - block_max)

            block_outputs.append(torch.matmul(exp_scores, v_block))
            block_max_vals.append(block_max)
            block_sum_exps.append(exp_scores.sum(dim=-1, keepdim=True))

        # Pass 2: per-query maximum over all blocks, then rescale and add up.
        global_max = torch.stack(block_max_vals, dim=0).max(dim=0).values

        total_sum_exp = torch.zeros_like(block_sum_exps[0])
        final_output = torch.zeros_like(block_outputs[0])
        for out_b, max_b, sum_b in zip(block_outputs, block_max_vals, block_sum_exps):
            weight = torch.exp(max_b - global_max)
            total_sum_exp = total_sum_exp + sum_b * weight
            final_output = final_output + out_b * weight

        final_output = final_output / total_sum_exp
        return final_output, self._full_attention_weights(q, k)

    def verify_with_tolerance(self, batch_size=2, seq_len_q=1, seq_len_kv=1024):
        """
        Compare both flash variants with ordinary attention for several block sizes.

        Random inputs are drawn with a fixed seed (42); results are printed.

        Returns:
            bool: True when every variant stays within the absolute tolerance
            (1e-4) for every tested block size.
        """
        torch.manual_seed(42)  # fixed seed for reproducibility
        q = torch.randn(batch_size, self.num_heads, seq_len_q, self.head_dim)
        k = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)
        v = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)

        print("=" * 70)
        print("Flash-Decoding 正确性验证")
        print("=" * 70)

        traditional_output, _ = self.traditional_attention(q, k, v)
        # Denominator of the relative error; the epsilon avoids division by zero.
        reference_scale = torch.abs(traditional_output).max().item() + 1e-10

        tolerance = 1e-4
        block_sizes = [16, 32, 64, 128, 256]
        all_correct = True

        for block_size in block_sizes:
            print(f"\n使用block_size={block_size}:")

            flash_output1, _ = self.flash_decoding_attention(q, k, v, block_size)
            flash_output2, _ = self.flash_decoding_attention_simple(q, k, v, block_size)

            diff1 = torch.abs(traditional_output - flash_output1).max().item()
            diff2 = torch.abs(traditional_output - flash_output2).max().item()

            rel_error1 = diff1 / reference_scale
            rel_error2 = diff2 / reference_scale

            is_correct1 = diff1 < tolerance
            is_correct2 = diff2 < tolerance
            all_correct = all_correct and is_correct1 and is_correct2

            print(f"  一般方法 - 最大绝对误差: {diff1:.2e}, 相对误差: {rel_error1:.2e}, 正确: {is_correct1}")
            print(f"  简化方法 - 最大绝对误差: {diff2:.2e}, 相对误差: {rel_error2:.2e}, 正确: {is_correct2}")

            # When both are correct, also report how closely they agree.
            if is_correct1 and is_correct2:
                method_diff = torch.abs(flash_output1 - flash_output2).max().item()
                print(f"  两种方法间差异: {method_diff:.2e}")

        return all_correct

    def analyze_numerical_stability(self):
        """
        Compare flash and ordinary attention on inputs of growing magnitude.

        For each value range, uniformly distributed q/k/v are drawn and the
        maximum absolute error plus the presence of NaN/Inf in the flash
        output are printed.

        Returns:
            bool: True when no flash output contained NaN or Inf.
        """
        print("\n" + "=" * 70)
        print("数值稳定性分析")
        print("=" * 70)

        test_cases = [
            ("小数值范围", (-1.0, 1.0)),
            ("中等数值范围", (-10.0, 10.0)),
            ("大数值范围", (-50.0, 50.0)),
        ]

        all_finite = True

        for name, (min_val, max_val) in test_cases:
            print(f"\n{name} [{min_val}, {max_val}]:")

            span = max_val - min_val
            q = torch.rand(1, self.num_heads, 1, self.head_dim) * span + min_val
            k = torch.rand(1, self.num_heads, 1024, self.head_dim) * span + min_val
            v = torch.rand(1, self.num_heads, 1024, self.head_dim) * span + min_val

            traditional_output, _ = self.traditional_attention(q, k, v)
            flash_output, _ = self.flash_decoding_attention(q, k, v, block_size=64)

            diff = torch.abs(traditional_output - flash_output).max().item()

            has_nan = torch.isnan(flash_output).any().item()
            has_inf = torch.isinf(flash_output).any().item()
            all_finite = all_finite and not (has_nan or has_inf)

            print(f"  最大绝对误差: {diff:.2e}")
            print(f"  包含NaN: {has_nan}, 包含Inf: {has_inf}")

        return all_finite


In [5]:
# Build the demo instance (8 heads of width 64).
demo = FlashDecodingDemo(d_model=512, num_heads=8)

# 1. Correctness check against ordinary attention.
demo.verify_with_tolerance(
    batch_size=2,
    seq_len_q=1,
    seq_len_kv=1024
)

# 2. Numerical-stability check.
demo.analyze_numerical_stability()

# 3. Timing comparison.
print("\n" + "=" * 70)
print("性能对比演示（小batch_size，长序列）")
print("=" * 70)

# Long-context decoding scenario: one query token, growing KV length.
seq_lengths = [1024, 4096, 16384, 32768]

for seq_len in seq_lengths:
    print(f"\n序列长度: {seq_len}")

    # Shapes follow the demo instance so the scores are scaled consistently.
    q = torch.randn(1, demo.num_heads, 1, demo.head_dim)  # batch=1, single-token query
    k = torch.randn(1, demo.num_heads, seq_len, demo.head_dim)
    v = torch.randn(1, demo.num_heads, seq_len, demo.head_dim)

    # Ordinary attention. perf_counter is the monotonic high-resolution clock
    # intended for measuring short durations.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    traditional_output, _ = demo.traditional_attention(q, k, v)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    traditional_time = (time.perf_counter() - start) * 1000

    # Flash-Decoding. The measured time includes the demo's progress prints
    # and the full attention weights it recomputes for verification.
    start = time.perf_counter()
    flash_output, _ = demo.flash_decoding_attention(q, k, v, block_size=256)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    flash_time = (time.perf_counter() - start) * 1000

    # Consistency of the two outputs.
    diff = torch.abs(traditional_output - flash_output).max().item()

    print(f"  传统方法: {traditional_time:.2f}ms")
    print(f"  Flash-Decoding: {flash_time:.2f}ms")
    print(f"  加速比: {traditional_time / flash_time:.2f}x")
    print(f"  输出差异: {diff:.2e}")
    print(f"  结果一致: {diff < 1e-4}")

Flash-Decoding 正确性验证

使用block_size=16:

分布式数组实现: 5个计算流
每个流有自己的O、M、L数组
并行处理64个tile...
归约完成，最终输出形状: torch.Size([2, 8, 1, 64])
  一般方法 - 最大绝对误差: 1.27e-07, 相对误差: 8.22e-07, 正确: True
  简化方法 - 最大绝对误差: 1.49e-07, 相对误差: 9.67e-07, 正确: True
  两种方法间差异: 6.71e-08

使用block_size=32:

分布式数组实现: 5个计算流
每个流有自己的O、M、L数组
并行处理32个tile...
归约完成，最终输出形状: torch.Size([2, 8, 1, 64])
  一般方法 - 最大绝对误差: 1.42e-07, 相对误差: 9.19e-07, 正确: True
  简化方法 - 最大绝对误差: 1.49e-07, 相对误差: 9.67e-07, 正确: True
  两种方法间差异: 4.47e-08

使用block_size=64:

分布式数组实现: 5个计算流
每个流有自己的O、M、L数组
并行处理16个tile...
归约完成，最终输出形状: torch.Size([2, 8, 1, 64])
  一般方法 - 最大绝对误差: 1.27e-07, 相对误差: 8.22e-07, 正确: True
  简化方法 - 最大绝对误差: 1.56e-07, 相对误差: 1.02e-06, 正确: True
  两种方法间差异: 4.47e-08

使用block_size=128:

分布式数组实现: 5个计算流
每个流有自己的O、M、L数组
并行处理8个tile...
归约完成，最终输出形状: torch.Size([2, 8, 1, 64])
  一般方法 - 最大绝对误差: 1.27e-07, 相对误差: 8.22e-07, 正确: True
  简化方法 - 最大绝对误差: 1.27e-07, 相对误差: 8.22e-07, 正确: True
  两种方法间差异: 4.47e-08

使用block_size=256:

分布式数组实现: 5个计算流
每个流有自己的O、M、L数组
并行处理4个tile...
归约完成，最终

### 2.2 方式二：保存 log-sum-exp

In [6]:
import torch
import torch.nn.functional as F
import math

class FinalFlashDecodingTiling:
    """Flash-Decoding tiling demo that keeps only ``(O, S)`` per partial result.

    Every KV tile produces a locally normalised output ``O_i`` and the
    log-sum-exp ``S_i = m_i + log(l_i)`` of its attention scores.  Partial
    results are combined with the two-step rule::

        S_global = log(sum_i exp(S_i))
        O_global = sum_i O_i * exp(S_i - S_global)
    """

    def __init__(self, d_model: int = 512, num_heads: int = 8):
        self.d_model = d_model
        self.num_heads = num_heads
        # Floor division: a remainder of d_model / num_heads is silently dropped.
        self.head_dim = d_model // num_heads

    def traditional_attention(self, q, k, v):
        """Baseline: softmax attention over the whole KV sequence in one shot.

        Scores are scaled by ``1 / sqrt(self.head_dim)``, i.e. by the head size
        configured in ``__init__`` rather than by the last dimension of ``q``.

        Returns:
            ``softmax(q @ k^T / sqrt(head_dim)) @ v``.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, v)

    def compute_stream_output(self, q, k_tile, v_tile):
        """Compute the partial attention result for a single KV tile.

        Returns:
            ``(O_i, S_i)`` where ``O_i`` is the tile-local, already normalised
            attention output and ``S_i = m_i + log(l_i)`` is the log-sum-exp of
            the tile's scaled scores (last dimension kept with size 1).
        """
        S_tile = torch.matmul(q, k_tile.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Tile-local running max and exponential sum.
        m_i = S_tile.max(dim=-1, keepdim=True).values
        exp_tile = torch.exp(S_tile - m_i)
        l_i = exp_tile.sum(dim=-1, keepdim=True)

        # Normalise inside the tile so that only (O_i, S_i) has to be stored.
        O_i = torch.matmul(exp_tile, v_tile) / l_i

        # The max element contributes exp(0) = 1, so l_i >= 1 and the log needs
        # no epsilon guard.
        S_i = m_i + torch.log(l_i)

        return O_i, S_i

    @staticmethod
    def _merge_pair(O_a, S_a, O_b, S_b):
        """Merge two ``(O, S)`` partial results into a single ``(O, S)`` pair."""
        # logaddexp evaluates log(exp(a) + exp(b)) in a numerically stable way.
        S_ab = torch.logaddexp(S_a, S_b)
        O_ab = O_a * torch.exp(S_a - S_ab) + O_b * torch.exp(S_b - S_ab)
        return O_ab, S_ab

    def merge_streams_two_step(self, streams_data):
        """Reduce a list of ``(O_i, S_i)`` pairs with the two-step algorithm.

        1. Fold every ``S_i`` into ``S_global = log(sum_i exp(S_i))``.
        2. Re-weight each ``O_i`` by ``exp(S_i - S_global)`` and sum.

        Returns:
            The merged output tensor, or ``None`` when ``streams_data`` is empty.
        """
        if not streams_data:
            return None

        # Step 1: global log-sum-exp, folded pairwise.
        S_global = streams_data[0][1]
        for _, S_i in streams_data[1:]:
            S_global = torch.logaddexp(S_global, S_i)

        # Step 2: accumulate each stream's weighted contribution.
        O_global = torch.zeros_like(streams_data[0][0])
        for O_i, S_i in streams_data:
            O_global += O_i * torch.exp(S_i - S_global)

        return O_global

    def flash_decoding_with_lse(self, q, k, v,
                            tile_size_kv: int = 256,
                            num_streams: int = 4):
        """Flash-Decoding that stores only ``(O, S)`` per stream.

        The KV sequence (dimension 2 of ``k`` / ``v``) is cut into tiles of at
        most ``tile_size_kv`` positions.  Tiles are assigned round-robin to
        ``num_streams`` accumulators, each accumulator folds its tiles with the
        pairwise merge, and the accumulators are finally reduced with
        :meth:`merge_streams_two_step`.

        Returns:
            The attention output.  For an empty KV sequence a zero tensor of
            shape ``q.shape[:-1] + (v.shape[-1],)`` is returned instead of NaN.
        """
        seq_len_kv = k.shape[2]
        num_tiles = (seq_len_kv + tile_size_kv - 1) // tile_size_kv

        print(f"使用两步合并算法Flash-Decoding: {num_streams}个流")

        # None marks a stream that has not received a tile yet; this avoids
        # -inf sentinel tensors and a tensor reduction on every tile.
        streams_data = [None] * num_streams

        print(f"处理{num_tiles}个tile...")

        for tile_idx in range(num_tiles):
            stream_id = tile_idx % num_streams

            start_idx = tile_idx * tile_size_kv
            end_idx = min(start_idx + tile_size_kv, seq_len_kv)

            partial = self.compute_stream_output(
                q,
                k[:, :, start_idx:end_idx, :],
                v[:, :, start_idx:end_idx, :],
            )

            accumulated = streams_data[stream_id]
            if accumulated is None:
                streams_data[stream_id] = partial
            else:
                streams_data[stream_id] = self._merge_pair(*accumulated, *partial)

        print("所有tile处理完成，开始归约所有流...")

        # Streams that never received a tile contribute nothing to the result.
        active_streams = [entry for entry in streams_data if entry is not None]
        if not active_streams:
            return q.new_zeros(*q.shape[:-1], v.shape[-1])

        return self.merge_streams_two_step(active_streams)

    def verify_correctness(self, seq_len_kv: int = 2048):
        """Compare the LSE-based Flash-Decoding against the baseline and print a report.

        Runs one random test with ``seq_len_kv`` KV positions and one tiny
        fixed-size test (8 KV positions, 2 tiles, 2 streams).

        Returns:
            A dict with keys ``'baseline'``, ``'output'``, ``'error'`` (max
            absolute difference) and ``'correct'`` (``error < 1e-4``), all
            referring to the first, large test.
        """
        torch.manual_seed(42)
        batch_size = 2
        seq_len_q = 1

        q = torch.randn(batch_size, self.num_heads, seq_len_q, self.head_dim)
        k = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)
        v = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)

        print("=" * 80)
        print("基于lse的Flash-Decoding")
        print("=" * 80)

        # Reference result.
        print("\n1. 传统注意力计算...")
        baseline = self.traditional_attention(q, k, v)

        # Tiled result.
        print("\n2. 基于lse的Flash-Decoding...")
        output = self.flash_decoding_with_lse(q, k, v)

        diff = torch.abs(baseline - output).max().item()
        rel_error = diff / torch.abs(baseline).max().item()

        print("\n验证结果:")
        print(f"  最大绝对误差: {diff:.2e}")
        print(f"  相对误差: {rel_error:.2e}")
        print(f"  合并算法是否正确: {diff < 1e-4}")

        print("\n3. 数学正确性验证（小规模测试）...")

        # Tiny case: 8 KV positions split into 2 tiles on 2 streams.
        torch.manual_seed(123)
        q_test = torch.randn(1, 2, 1, 4)
        k_test = torch.randn(1, 2, 8, 4)
        v_test = torch.randn(1, 2, 8, 4)

        baseline_test = self.traditional_attention(q_test, k_test, v_test)
        output_test = self.flash_decoding_with_lse(q_test, k_test, v_test,
                                               tile_size_kv=4, num_streams=2)

        diff_test = torch.abs(baseline_test - output_test).max().item()
        print(f"  小规模测试最大绝对误差: {diff_test:.2e}")
        print(f"  小规模测试是否正确: {diff_test < 1e-4}")

        return {
            'baseline': baseline,
            'output': output,
            'error': diff,
            'correct': diff < 1e-4
        }


# Script entry point: run the correctness check, then a single-shot timing
# comparison between the baseline and the LSE-based Flash-Decoding.
if __name__ == "__main__":
    import time

    def _timed(fn, *args, **kwargs):
        """Call ``fn`` once and return ``(result, elapsed_seconds)``."""
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        # Wait for pending GPU work (if a GPU is present) before stopping the clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return result, time.perf_counter() - start

    demo = FinalFlashDecodingTiling(d_model=512, num_heads=8)

    # Correctness check of the final implementation.
    print("基于lse的Flash-Decoding验证")
    results = demo.verify_correctness(seq_len_kv=2048)

    if results['correct']:
        print("\n✅ 算法现在可以正确合并结果。")
    else:
        print(f"\n❌ 仍然存在问题，误差: {results['error']:.2e}")

    # Timing comparison (one un-warmed run each, so numbers are only indicative).
    print("\n" + "=" * 80)
    print("性能测试")
    print("=" * 80)

    torch.manual_seed(42)
    batch_size = 2
    seq_len_q = 1
    seq_len_kv = 8192

    q = torch.randn(batch_size, 8, seq_len_q, 64)
    k = torch.randn(batch_size, 8, seq_len_kv, 64)
    v = torch.randn(batch_size, 8, seq_len_kv, 64)

    # Baseline.
    baseline, trad_time = _timed(demo.traditional_attention, q, k, v)

    # Tiled variant (its own progress prints are included in the measured time).
    output, opt_time = _timed(demo.flash_decoding_with_lse, q, k, v, tile_size_kv=256)

    diff = torch.abs(baseline - output).max().item()

    print(f"序列长度: {seq_len_kv}")
    print(f"传统方法时间: {trad_time:.4f}s")
    print(f"优化方法时间: {opt_time:.4f}s")
    print(f"加速比: {trad_time/opt_time:.2f}x")
    print(f"误差: {diff:.2e}")
    print(f"是否一致: {diff < 1e-4}")

基于lse的Flash-Decoding验证
基于lse的Flash-Decoding

1. 传统注意力计算...

2. 基于lse的Flash-Decoding...
使用两步合并算法Flash-Decoding: 4个流
处理8个tile...
所有tile处理完成，开始归约所有流...

验证结果:
  最大绝对误差: 2.53e-07
  相对误差: 1.68e-06
  合并算法是否正确: True

3. 数学正确性验证（小规模测试）...
使用两步合并算法Flash-Decoding: 2个流
处理2个tile...
所有tile处理完成，开始归约所有流...
  小规模测试最大绝对误差: 1.19e-07
  小规模测试是否正确: True

✅ 算法现在可以正确合并结果。

性能测试
使用两步合并算法Flash-Decoding: 4个流
处理32个tile...
所有tile处理完成，开始归约所有流...
序列长度: 8192
传统方法时间: 0.0163s
优化方法时间: 0.0453s
加速比: 0.36x
误差: 1.60e-07
是否一致: True


## 2.3 方式二的等价性证明

In [7]:
def mathematical_equivalence_proof():
    """Print a step-by-step argument that the two-step merge equals the classic merge.

    The text defines the per-block quantities ``m_i``, ``l_i``, the normalised
    ``O_i`` and ``S_i = m_i + log(l_i)``, states both merge rules, and then
    shows in five steps that they give the same ``O_global``.

    Because ``O_i`` is already divided by ``l_i``, the classic merge has to
    re-scale it by ``l_i * exp(m_i - M_global)``; the printed formula includes
    that ``l_i`` factor so that it agrees with step 5 of the proof.

    Returns:
        None. The derivation is written to standard output.
    """
    rule = "=" * 80
    lines = (
        "\n" + rule,
        "数学等价性证明",
        rule,
        # Full (un-tiled) attention.
        "传统注意力计算：",
        "  设全局有N个注意力分数",
        "  全局最大值 M = max(score_j), j=1..N",
        "  全局指数和 L = Σ_j exp(score_j - M)",
        "  注意力输出 = Σ_j [exp(score_j - M) * v_j] / L",
        "",
        # Per-block quantities.
        "Flash-Decoding分块计算：",
        "  将N个分数分成K个块，每个块i有：",
        "    m_i = 块内最大值",
        "    l_i = Σ_{j∈块i} exp(score_j - m_i)",
        "    O_i = Σ_{j∈块i} [exp(score_j - m_i) * v_j] / l_i",
        "    S_i = m_i + log(l_i)",
        "",
        # Classic merge based on (m, l, O).
        "传统合并算法：",
        "  1. 找到全局最大值 M_global = max(m_i)",
        "  2. 调整每个块的贡献：",
        "     调整后l_i' = l_i × exp(m_i - M_global)",
        "     调整后O_i' = O_i × l_i × exp(m_i - M_global)",
        "  3. 合并：",
        "     L_global = Σ_i l_i'",
        "     O_global = Σ_i O_i' / L_global",
        "",
        # Two-step merge based on (O, S).
        "两步合并算法：",
        "  1. 计算 S_global = log(Σ_i exp(S_i))",
        "  2. 计算 O_global = Σ_i [O_i × exp(S_i - S_global)]",
        "",
        # Equivalence proof.
        "证明等价性：",
        "  步骤1：证明 exp(S_i) = l_i × exp(m_i)",
        "     因为 S_i = m_i + log(l_i)",
        "     所以 exp(S_i) = exp(m_i + log(l_i)) = l_i × exp(m_i)",
        "",
        "  步骤2：证明 exp(S_global) = Σ_i [l_i × exp(m_i)]",
        "     因为 S_global = log(Σ_i exp(S_i))",
        "     所以 exp(S_global) = Σ_i exp(S_i) = Σ_i [l_i × exp(m_i)]",
        "",
        "  步骤3：证明 exp(S_i - S_global) = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]",
        "     exp(S_i - S_global) = exp(S_i) / exp(S_global)",
        "                        = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]",
        "",
        "  步骤4：证明 O_i × exp(S_i - S_global) = [O_i × l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]",
        "     这显然成立",
        "",
        "  步骤5：证明 Σ_i [O_i × exp(S_i - S_global)] 等于传统合并结果",
        "     传统合并：O_global = Σ_i [O_i × l_i × exp(m_i - M_global)] / Σ_i [l_i × exp(m_i - M_global)]",
        "                     = Σ_i [O_i × l_i × exp(m_i)] / Σ_i [l_i × exp(m_i)]  （乘以exp(M_global)）",
        "                     = Σ_i [O_i × exp(S_i)] / Σ_i [exp(S_i)]",
        "                     = Σ_i [O_i × exp(S_i - S_global)]  （由步骤3）",
        "",
    )
    for line in lines:
        print(line)

mathematical_equivalence_proof()


数学等价性证明
传统注意力计算：
  设全局有N个注意力分数
  全局最大值 M = max(score_j), j=1..N
  全局指数和 L = Σ_j exp(score_j - M)
  注意力输出 = Σ_j [exp(score_j - M) * v_j] / L

Flash-Decoding分块计算：
  将N个分数分成K个块，每个块i有：
    m_i = 块内最大值
    l_i = Σ_{j∈块i} exp(score_j - m_i)
    O_i = Σ_{j∈块i} [exp(score_j - m_i) * v_j] / l_i
    S_i = m_i + log(l_i)

传统合并算法：
  1. 找到全局最大值 M_global = max(m_i)
  2. 调整每个块的贡献：
     调整后l_i' = l_i × exp(m_i - M_global)
     调整后O_i' = O_i × l_i × exp(m_i - M_global)
  3. 合并：
     L_global = Σ_i l_i'
     O_global = Σ_i O_i' / L_global

两步合并算法：
  1. 计算 S_global = log(Σ_i exp(S_i))
  2. 计算 O_global = Σ_i [O_i × exp(S_i - S_global)]

证明等价性：
  步骤1：证明 exp(S_i) = l_i × exp(m_i)
     因为 S_i = m_i + log(l_i)
     所以 exp(S_i) = exp(m_i + log(l_i)) = l_i × exp(m_i)

  步骤2：证明 exp(S_global) = Σ_i [l_i × exp(m_i)]
     因为 S_global = log(Σ_i exp(S_i))
     所以 exp(S_global) = Σ_i exp(S_i) = Σ_i [l_i × exp(m_i)]

  步骤3：证明 exp(S_i - S_global) = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]
     exp(S_i - S_global) = exp(S_i) /