In [1]:
from IPython.display import Image

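- Goal: handle long sequences (e.g. a 1M context window);
    - DP, TP, PP & SP (data, tensor, pipeline and sequence parallelism)
    - SP splits a long sequence across devices; each device processes one sub-sequence (sub seq);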
- https://arxiv.org/pdf/2105.13120
    - Sequence Parallelism: Long Sequence Training from System Perspective
- https://arxiv.org/pdf/2309.14509
    - DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models

### Ring Attention

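- Ring-AllReduce-style communication: trade communication for memory
    - The sequence is split/sharded across multiple GPUs, so each GPU stores only one sub seq;
    - (Ring)QK & (Ring)AV
        - The Query of each device's sub seq must be computed against all the Keys held on the other devices;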
$$\text{Attention}(Q, K, V) = \underbrace{ \text{softmax}\left( \frac{QK^{\top}}{\sqrt{d_k}} \right) }_{\mathbf{A}} V$$

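- With N devices, N-1 ring iterations are enough: afterwards each device holds the complete QK^T rows for its own queries, i.e. a score block of shape (b, n/N, n)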

$$
\underset{\substack{\uparrow \\ (b, n, d_v)}}{\text{Attention}(Q, K, V)} = \underbrace{\text{softmax} \left( \frac{\overbrace{\underset{\substack{\uparrow \\ (b, n, d_k)}}{Q} \cdot \underset{\substack{\uparrow \\ (b, d_k, n)}}{K^T}}^{\text{Scores Dim: }(b, n, n)}}{\underset{\substack{\uparrow \\ \text{scalar}}}{\sqrt{d_k}}} \right)}_{\text{Weights Dim: }(b, n, n)} \cdot \underset{\substack{\uparrow \\ (b, n, d_v)}}{V}
$$
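
The per-device computation does not have to materialize the full score row before applying softmax. As a sketch (the block notation $Q_i$, $K_j$, $V_j$, $S_{ij}$, $m_i$, $\ell_i$ is introduced here for illustration and is not from the cited papers), the output for the local query block $Q_i$ can be written as a sum over the key/value blocks that arrive around the ring:

$$
\begin{aligned}
S_{ij} &= \frac{Q_i K_j^{\top}}{\sqrt{d_k}}, \qquad j = 1, \dots, N \\
m_i &= \max_{j} \operatorname{rowmax}(S_{ij}) \\
\ell_i &= \sum_{j=1}^{N} \operatorname{rowsum}\left(\exp(S_{ij} - m_i)\right) \\
O_i &= \operatorname{diag}(\ell_i)^{-1} \sum_{j=1}^{N} \exp(S_{ij} - m_i)\, V_j
\end{aligned}
$$

Here $m_i$ and $\ell_i$ are per-row vectors and the subtraction broadcasts across columns. Since each term depends only on one $(K_j, V_j)$ pair, a device can keep running values of $m_i$, $\ell_i$ and the unnormalized sum, and update them as each block arrives.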

In [3]:
Image(url='./imgs/ring-attn.png', width=500)

### DeepSpeed Ulysses

- Ulysses: as in the novel *Ulysses* (a very long novel);
- all-to-all communication collective
    - DeepSpeed-Ulysses partitions individual samples along the sequence dimension among participating GPUs.
    - Then right before the attention computation, it employs all-to-all communication collective on the **partitioned queries, keys and values** such that each GPU receives the full sequence but only for a **non-overlapping subset of the attention heads**. This allows the participating GPUs to compute attention for different attention heads in parallel.
        - **gather_seq_scatter_heads**
    - Finally, DeepSpeed-Ulysses employs another all-to-all to **gather the results along the attention heads** while re-partitioning along the sequence dimension.
        - **gather_heads_scatter_seq**
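- The input sequence X (length N) is split along the sequence dimension into SP chunks, so each GPU gets a sub-sequence of length N/SP.
    - For non-attention layers (e.g. MLP), computation is fully local: each GPU only processes its own sub-sequence.
        - Tokens are independent of each other (token-level projection).
        - The core complexity of Ulysses SP lies in the attention layer. So that every token can attend over the global sequence (equivalently, so that each head sees the full sequence even though that head is computed only on the current rank), the attention module needs two carefully arranged all-to-all data reshuffles, one before and one after. The MLP layer has no such requirement: its input is already sharded by sequence and can be processed locally.
    - For attention layers, a simpler all-gather-based form of sequence parallelism works as follows. Note that this is not the Ulysses all-to-all described above: Ulysses redistributes Q, K and V so that each GPU holds the full sequence for a subset of heads, whereas here only K and V are gathered and Q stays sequence-sharded.
        - Step 1 (compute Q, K, V): each GPU computes Q_local, K_local, V_local from its local sub-sequence (shape roughly N/SP x d, where d is the hidden dimension).
        - Step 2 (gather global K, V; the key step): use an **All-Gather** collective (not All-to-All, since every rank ends up with the same full result). Each GPU sends its K_local, V_local to all other GPUs and receives their K, V blocks. Afterwards, **every GPU holds the full global K and V matrices (shape N x d)**, but still only its own Q_local (shape N/SP x d).
            - https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html
        - Step 3 (local attention): each GPU uses Q_local and the full global K, V to compute its share of the attention output O_local (shape N/SP x d), i.e. Attention(Q_local, K_global, V_global). The compute cost is on the order of (N/SP) * N * d, and the memory bottleneck is the temporary attention score matrix, of size about **(N/SP) * N**. Compared with the original **N*N**, this is a reduction by a factor of SP.
        - Step 4 (optional output regrouping): if a later layer needs the full output concatenated along the sequence, another communication step (an All-Gather or an All-to-All variant) may be needed to combine the O_local blocks. In the DeepSpeed implementation, however, the output usually stays distributed and is fed directly into the next layer, which is also sequence-parallel.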
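
Steps 1 to 4 above can be checked in a single process. The following is a minimal sketch under stated assumptions: the sizes `N`, `SP`, `d` are toy values, a single head is used, and the collective is only emulated with `torch.cat` (no real communication takes place).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, SP, d = 16, 4, 8  # sequence length, SP degree, hidden size (toy values)
x_q, x_k, x_v = (torch.randn(N, d) for _ in range(3))

# Step 1: each rank holds only its own N/SP rows of Q, K, V.
q_local = torch.chunk(x_q, SP, dim=0)
k_local = torch.chunk(x_k, SP, dim=0)
v_local = torch.chunk(x_v, SP, dim=0)

# Step 2: emulate the gather; every rank ends up with the full K and V.
k_global = torch.cat(k_local, dim=0)
v_global = torch.cat(v_local, dim=0)

# Step 3: each rank attends from its local queries to all keys.
outputs, score_shapes = [], []
for rank in range(SP):
    scores = q_local[rank] @ k_global.T / d ** 0.5  # shape (N/SP, N)
    score_shapes.append(tuple(scores.shape))
    outputs.append(F.softmax(scores, dim=-1) @ v_global)

# Step 4 (optional): concatenate the per-rank outputs along the sequence.
out_sp = torch.cat(outputs, dim=0)
out_ref = F.softmax(x_q @ x_k.T / d ** 0.5, dim=-1) @ x_v

print("per-rank score shape:", score_shapes[0])
print("matches full attention:", torch.allclose(out_sp, out_ref, atol=1e-5))
```

The per-rank score matrix has N/SP rows instead of N, which is where the memory saving in step 3 comes from.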

### verl sp

- `torchrun --nproc_per_node=2 -m pytest tests/model/test_transformers_ulysses.py -svv`
    - dp_size = world_size // sp_size
- monkey_patch
    - `_flash_attention_forward` => `_ulysses_flash_attention_forward`
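    - Assume the sequence-parallel degree is `ulysses_sp_size = N`. Each SP rank initially holds Q, K, V tensors of shape `(batch_size, seq_len / N, num_heads, head_dim)`.
        - gather_seq_scatter_heads
            - `[bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]` (for Q/K/V)
                - result: the full sequence, a subset of the heads;
        - flash-attn => `[bsz, seq, h/n, ...]`
        - gather_heads_scatter_seq
            - `[bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]`
                - result: a sub-sequence, all of the heads;
- Data parallelism (fsdp) and sp
    - fsdp: reduces the GPU memory taken by model parameters; sp: reduces the GPU memory taken by activations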
    - fsdp: all-gather, reduce-scatter
    - sp: all-to-all
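
The two reshuffles can be emulated in one process to confirm that head-sharded attention reproduces the unsharded result. This is a sketch, not verl's or DeepSpeed's implementation: `all_to_all_sim` and `attention` are hypothetical helpers written for this example, and the all-to-all is emulated with `torch.chunk` and `torch.cat` instead of a real collective.

```python
import torch
import torch.nn.functional as F

def all_to_all_sim(per_rank, scatter_dim, gather_dim):
    """Single-process stand-in for an all-to-all: rank `src` splits its tensor
    along scatter_dim and sends piece `dst` to rank `dst`, which concatenates
    the received pieces along gather_dim in source-rank order."""
    n = len(per_rank)
    pieces = [torch.chunk(t, n, dim=scatter_dim) for t in per_rank]
    return [torch.cat([pieces[src][dst] for src in range(n)], dim=gather_dim)
            for dst in range(n)]

def attention(q, k, v):
    # q, k, v: [bsz, seq, heads, head_dim]; each head is computed independently.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [bsz, heads, seq, head_dim]
    w = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return (w @ v).transpose(1, 2)                     # -> [bsz, seq, heads, head_dim]

torch.manual_seed(0)
bsz, seq, heads, head_dim, n = 2, 8, 4, 16, 2  # toy sizes; n is the SP degree
q, k, v = (torch.randn(bsz, seq, heads, head_dim) for _ in range(3))

# Each SP rank starts with [bsz, seq/n, heads, head_dim].
q_sh, k_sh, v_sh = (list(torch.chunk(t, n, dim=1)) for t in (q, k, v))

# gather_seq_scatter_heads: -> [bsz, seq, heads/n, head_dim]
q_h, k_h, v_h = (all_to_all_sim(t, scatter_dim=2, gather_dim=1)
                 for t in (q_sh, k_sh, v_sh))

# Local attention over the full sequence, for this rank's heads only.
o_h = [attention(q_h[r], k_h[r], v_h[r]) for r in range(n)]

# gather_heads_scatter_seq: -> [bsz, seq/n, heads, head_dim]
o_sh = all_to_all_sim(o_h, scatter_dim=1, gather_dim=2)

out_sp = torch.cat(o_sh, dim=1)
out_ref = attention(q, k, v)
print(q_h[0].shape, o_sh[0].shape)
print("matches full attention:", torch.allclose(out_sp, out_ref, atol=1e-5))
```

Because attention heads never interact inside the softmax, no approximation is involved; the only requirement in this sketch is that the number of heads is divisible by the SP degree.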

```
       SP=4 (columns) -->
DP=2   GPU(0,0) GPU(0,1) GPU(0,2) GPU(0,3)  <-- DP Group 0 (Row 0)
(rows) GPU(1,0) GPU(1,1) GPU(1,2) GPU(1,3)  <-- DP Group 1 (Row 1)
 |
 V
```
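
The grid can be turned into rank arithmetic. The sketch below assumes a row-major layout with SP as the inner (fastest-varying) dimension, which is an assumption for illustration; `mesh_coords` is a hypothetical helper, not a verl or PyTorch function.

```python
def mesh_coords(rank: int, world_size: int, sp_size: int) -> tuple[int, int]:
    """Map a global rank to (dp_index, sp_index), assuming row-major order
    with SP as the inner dimension."""
    assert world_size % sp_size == 0, "world_size must be divisible by sp_size"
    return rank // sp_size, rank % sp_size

world_size, sp_size = 8, 4
dp_size = world_size // sp_size

# Rows of the grid: ranks that share one sample and split its sequence (all-to-all).
sp_groups = [[r for r in range(world_size)
              if mesh_coords(r, world_size, sp_size)[0] == dp]
             for dp in range(dp_size)]

# Columns of the grid: ranks at the same SP position across different rows.
dp_groups = [[row[sp] for row in sp_groups] for sp in range(sp_size)]

print("dp_size:", dp_size)
print("sp groups (rows):", sp_groups)
print("dp groups (columns):", dp_groups)
```

With `world_size = 8` and `sp_size = 4` this gives two rows of four ranks, matching the diagram.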

In [4]:
import torch
import torch.nn.functional as F

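# --- Parameters ---
batch_size = 1
seq_len = 12  # total sequence length
d_model = 8   # embedding dimension (kept small for clarity)
num_devices = 3  # number of simulated devices / chunks
chunk_len = seq_len // num_devices  # length of the sequence chunk on each device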

In [5]:
assert seq_len % num_devices == 0, "seq_len must be divisible by num_devices"

In [6]:
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

In [7]:
scale = d_model ** -0.5  # scaling factor 1/sqrt(d_k)
# Attention scores: Q @ K^T
attn_scores_standard = torch.matmul(Q, K.transpose(-2, -1)) * scale
# Softmax over the key dimension gives the attention weights
attn_weights_standard = F.softmax(attn_scores_standard, dim=-1)
# Apply the weights to V to get the output
output_standard = torch.matmul(attn_weights_standard, V)

In [8]:
output_standard.shape

torch.Size([1, 12, 8])

### Ring Self-Attention (RSA)

In [9]:
Q_chunks = list(torch.chunk(Q, num_devices, dim=1))
K_chunks = list(torch.chunk(K, num_devices, dim=1))
V_chunks = list(torch.chunk(V, num_devices, dim=1))

print(f"Q split into {len(Q_chunks)} chunks, each of shape: {Q_chunks[0].shape}")
print(f"K split into {len(K_chunks)} chunks, each of shape: {K_chunks[0].shape}")
print(f"V split into {len(V_chunks)} chunks, each of shape: {V_chunks[0].shape}")

Q split into 3 chunks, each of shape: torch.Size([1, 4, 8])
K split into 3 chunks, each of shape: torch.Size([1, 4, 8])
V split into 3 chunks, each of shape: torch.Size([1, 4, 8])


In [14]:
# --- 2. Ring Self-Attention Simulation ---
print("\n--- Simulating Ring Self-Attention ---")

# Split tensors into chunks for each "device"
Q_chunks = list(torch.chunk(Q, num_devices, dim=1))
K_chunks = list(torch.chunk(K, num_devices, dim=1))
V_chunks = list(torch.chunk(V, num_devices, dim=1))

print(f"Split Q into {len(Q_chunks)} chunks, each shape: {Q_chunks[0].shape}")
print(f"Split K into {len(K_chunks)} chunks, each shape: {K_chunks[0].shape}")
print(f"Split V into {len(V_chunks)} chunks, each shape: {V_chunks[0].shape}")

output_chunks_rsa = []

# Simulate computation on each device
for i in range(num_devices):
    print(f"\n-- Simulating Device {i} --")
    q_local = Q_chunks[i] # Query chunk for this device
    ordered_scores = [None] * num_devices

    # Ring communication for Keys
    print(f"  Device {i} Q shape: {q_local.shape}")
    for j in range(num_devices):
        k_idx = (i - j + num_devices) % num_devices # Index of K chunk received in this step
        k_remote = K_chunks[k_idx]
        print(f"  Step {j}: Device {i} using K chunk from Device {k_idx} (Shape: {k_remote.shape})")

        # Calculate partial attention scores: Q_local @ K_remote^T
        scores_part = torch.matmul(q_local, k_remote.transpose(-2, -1)) * scale
        print(f"    Partial scores shape for K_{k_idx}: {scores_part.shape}")
        ordered_scores[k_idx] = scores_part

    # Concatenate partial scores in the correct order (k=0, 1, ..., N-1)
    all_scores_for_q_i = torch.cat(ordered_scores, dim=-1)
    print(f"  Device {i}: Concatenated scores shape (Correct Order): {all_scores_for_q_i.shape}") # Should be [batch, chunk_len, seq_len]

    # Apply Softmax
    attn_weights_for_q_i = F.softmax(all_scores_for_q_i, dim=-1)
    print(f"  Device {i}: Softmax weights shape: {attn_weights_for_q_i.shape}")

    # Apply weights to Value matrix (using reconstructed full V for equivalence check)
    full_V = torch.cat(V_chunks, dim=1) # Reconstruct full V for calculation
    output_chunk_i = torch.matmul(attn_weights_for_q_i, full_V)
    print(f"  Device {i}: Output chunk shape: {output_chunk_i.shape}") # Should be [batch, chunk_len, d_model]

    output_chunks_rsa.append(output_chunk_i)


--- Simulating Ring Self-Attention ---
Split Q into 3 chunks, each shape: torch.Size([1, 4, 8])
Split K into 3 chunks, each shape: torch.Size([1, 4, 8])
Split V into 3 chunks, each shape: torch.Size([1, 4, 8])

-- Simulating Device 0 --
  Device 0 Q shape: torch.Size([1, 4, 8])
  Step 0: Device 0 using K chunk from Device 0 (Shape: torch.Size([1, 4, 8]))
    Partial scores shape for K_0: torch.Size([1, 4, 4])
  Step 1: Device 0 using K chunk from Device 2 (Shape: torch.Size([1, 4, 8]))
    Partial scores shape for K_2: torch.Size([1, 4, 4])
  Step 2: Device 0 using K chunk from Device 1 (Shape: torch.Size([1, 4, 8]))
    Partial scores shape for K_1: torch.Size([1, 4, 4])
  Device 0: Concatenated scores shape (Correct Order): torch.Size([1, 4, 12])
  Device 0: Softmax weights shape: torch.Size([1, 4, 12])
  Device 0: Output chunk shape: torch.Size([1, 4, 8])

-- Simulating Device 1 --
  Device 1 Q shape: torch.Size([1, 4, 8])
  Step 0: Device 1 using K chunk from Device 1 (Shape: torc

In [15]:
# Concatenate the output chunks from all devices
output_rsa = torch.cat(output_chunks_rsa, dim=1) # Concatenate along the sequence dimension
print("\n--- RSA Result ---")
print("RSA Concatenated Output Shape:", output_rsa.shape)

# --- 3. Comparison ---
print("\n--- Comparison ---")
# Check if the results are numerically close
are_close = torch.allclose(output_standard, output_rsa, atol=1e-6) # Use a tolerance

print(f"Are Standard Attention and Ring Attention outputs equivalent? {are_close}")

# Verify the shapes match
assert output_standard.shape == output_rsa.shape, "Shapes do not match!"
if are_close:
    print("Success: The Ring Self-Attention simulation produced the same result as standard attention.")
else:
    print("Failure: The results differ.")
    # Optional: Print difference magnitude if they differ
    # diff = torch.abs(output_standard - output_rsa).max()
    # print(f"Maximum absolute difference: {diff.item()}")


--- RSA Result ---
RSA Concatenated Output Shape: torch.Size([1, 12, 8])

--- Comparison ---
Are Standard Attention and Ring Attention outputs equivalent? True
Success: The Ring Self-Attention simulation produced the same result as standard attention.
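
The simulation above concatenates the full score row per device and reconstructs the full V for the final product. A variant that rotates K and V blocks together and never holds more than one block's scores at a time can be sketched with a running (online) softmax. This is an illustrative single-process sketch, not the implementation from either cited paper; the names `kv_held`, `row_max`, `denom` and `acc` are chosen here, and the ring transfer is emulated by rotating a Python list.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model, num_devices = 1, 12, 8, 3
chunk_len = seq_len // num_devices
scale = d_model ** -0.5
Q, K, V = (torch.randn(batch_size, seq_len, d_model) for _ in range(3))

Q_chunks = torch.chunk(Q, num_devices, dim=1)
# kv_held[i] is the (K, V) block currently resident on device i.
kv_held = list(zip(torch.chunk(K, num_devices, dim=1),
                   torch.chunk(V, num_devices, dim=1)))

# Per-device running statistics: row max, softmax denominator, unnormalized output.
row_max = [torch.full((batch_size, chunk_len, 1), float("-inf")) for _ in range(num_devices)]
denom = [torch.zeros(batch_size, chunk_len, 1) for _ in range(num_devices)]
acc = [torch.zeros(batch_size, chunk_len, d_model) for _ in range(num_devices)]

for step in range(num_devices):
    for i in range(num_devices):
        k_blk, v_blk = kv_held[i]
        s = Q_chunks[i] @ k_blk.transpose(-2, -1) * scale   # (b, chunk_len, chunk_len)
        new_max = torch.maximum(row_max[i], s.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max[i] - new_max)        # rescales the old statistics
        p = torch.exp(s - new_max)
        denom[i] = denom[i] * correction + p.sum(dim=-1, keepdim=True)
        acc[i] = acc[i] * correction + p @ v_blk
        row_max[i] = new_max
    if step < num_devices - 1:
        # Ring step: device i receives the block previously held by device i - 1.
        kv_held = kv_held[-1:] + kv_held[:-1]

output_ring = torch.cat([acc[i] / denom[i] for i in range(num_devices)], dim=1)
output_ref = F.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1) @ V
print("matches standard attention:", torch.allclose(output_ring, output_ref, atol=1e-5))
```

Each device performs N block computations but only N-1 transfers, and the largest temporary is a `chunk_len x chunk_len` score block rather than a `chunk_len x seq_len` row.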
