# 我们将完成以下内容
1. 基准测试和性能分析框架
2. Flash Attention
3. Triton 内核
4. 分布式数据并行训练
5. 优化器状态分片

# 1. 基准测试和性能分析

## 首先检查GPU是否可用

通过 pyproject.toml 安装的 torch 是 CPU 版本。
如果要安装 GPU 版本：
```bash
uv pip uninstall torch
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
```
其中 cu121 对应 CUDA 12.1，请改为与你本机 CUDA 版本匹配的标识（例如 cu118、cu124）。

In [1]:
import torch
import time
from typing import Callable

import math
import time
from cs336_basics.model import BasicsTransformerLM 
from cs336_basics.optimizer import get_cosine_lr
from cs336_basics.optimizer import AdamW
from cs336_basics.data import get_batch
import torch.nn.functional as F

In [None]:

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Report GPU details only when CUDA is usable; otherwise say so explicitly.
if not torch.cuda.is_available():
    print("No CUDA GPU detected")
else:
    print("CUDA device count:", torch.cuda.device_count())
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))

# Apple-silicon backend; the hasattr guard covers torch builds without torch.backends.mps.
print("MPS（苹果的MPS） available:", hasattr(torch.backends, "mps") and torch.backends.mps.is_available())

In [27]:
def mean(values: list[float]) -> float:
    """Return the arithmetic mean of ``values``.

    Raises:
        ValueError: if ``values`` is empty.
    """
    if values:
        return sum(values) / len(values)
    raise ValueError("mean() requires at least one value")

def benchmark(description: str, run: Callable, num_warmups: int = 1, num_trials: int = 3) -> float:
    """Time ``run`` and return its mean wall-clock duration in milliseconds.

    Args:
        description: Label for this benchmark. It is informational only; it
            is neither printed nor stored.
        run: Zero-argument callable to time.
        num_warmups: Untimed calls made first, so one-off costs (compilation,
            caching) do not pollute the measurement.
        num_trials: Number of timed calls that are averaged.

    Returns:
        Mean duration of one ``run()`` call over ``num_trials`` trials, in ms.

    Raises:
        ValueError: if ``num_trials`` < 1 (no samples to average; raised by ``mean``).
    """
    # Warm-up: the first calls may be slow (compilation, caching). What matters
    # is the steady-state running time.
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        # CUDA work is queued asynchronously. Drain it so warm-up kernels are
        # not billed to the first timed trial.
        torch.cuda.synchronize()
    print('现在真正计时!')
    times: list[float] = []
    for _ in range(num_trials):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # on clock adjustments and is coarse on some platforms.
        start_time = time.perf_counter()
        run()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # include the GPU work queued by `run`
        end_time = time.perf_counter()
        times.append((end_time - start_time) * 1000)
    # Average over trials to smooth out run-to-run noise.
    return mean(times)

现在先用 `time.sleep` 来测试一下 benchmark 函数：每次睡眠 50 ms，结果应约为 50（单位：毫秒）。


In [None]:
# Sanity check: a 50 ms sleep should benchmark at roughly 50 (ms).
benchmark("sleep", run=lambda: time.sleep(0.05))

接下来实测一个 Transformer 语言模型训练步骤的耗时。

In [11]:
# --- Model / data hyperparameters (module-level globals, read by later cells such as run_model) ---
vocab_size = 50_000
context_length = 128
batch_size = 32

d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
rope_theta = 10000.0

# --- Learning-rate schedule arguments passed to get_cosine_lr in run_model ---
# NOTE(review): run_model restarts its step counter at 0 on every call, so short
# benchmark runs stay inside the 200-step warmup (recorded logs show lr 0.00e+00 at step 0).
max_lr = 3e-4
min_lr = 3e-5
warmup_iters = 200
cosine_cycle_iters = 10_000
# NOTE(review): not read by run_model in the visible code; its own num_steps
# parameter (default 1) shadows this global.
num_steps = 10
# Prefer the GPU when available; everything below is placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the language model and move its parameters to `device`.
model = BasicsTransformerLM(
    vocab_size=vocab_size,
    context_length=context_length,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    rope_theta=rope_theta,
).to(device)

# AdamW with weight_decay=0.1. The lr given here is only the initial value:
# run_model overwrites every param_group's "lr" on each step.
optimizer = AdamW(
    model.parameters(),
    lr=max_lr,
    weight_decay=0.1,
)



def lm_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean next-token cross-entropy over every position in the batch.

    Args:
        logits: Unnormalized scores of shape ``(B, T, V)``.
        targets: Integer class indices of shape ``(B, T)``.

    Returns:
        Scalar tensor: cross-entropy averaged over all ``B * T`` positions.
    """
    B, T, V = logits.shape
    # reshape (not view) so that non-contiguous inputs, e.g. transposed or
    # sliced logits/targets, are flattened instead of raising RuntimeError.
    return F.cross_entropy(
        logits.reshape(B * T, V),
        targets.reshape(B * T),
    )
import numpy as np

# Toy token stream 0..999: used here to sanity-check get_batch and, in later
# cells, as the data source for run_model.
dataset = np.arange(1000, dtype=np.int64)
x, y = get_batch(
    dataset=dataset,
    context_length=8,
    batch_size=4,
    device="cpu",
)

# Shapes first, then the tensors; y should be x shifted forward by one token.
print(x.shape, y.shape, sep="\n")  # (4, 8) each
print(x, y, sep="\n")

model.train()
print(f'设备：{device}')


torch.Size([4, 8])
torch.Size([4, 8])
tensor([[339, 340, 341, 342, 343, 344, 345, 346],
        [114, 115, 116, 117, 118, 119, 120, 121],
        [430, 431, 432, 433, 434, 435, 436, 437],
        [579, 580, 581, 582, 583, 584, 585, 586]])
tensor([[340, 341, 342, 343, 344, 345, 346, 347],
        [115, 116, 117, 118, 119, 120, 121, 122],
        [431, 432, 433, 434, 435, 436, 437, 438],
        [580, 581, 582, 583, 584, 585, 586, 587]])
设备：cuda


In [24]:
def run_model(num_steps: int = 1, log_every: int = 100):
    """Run ``num_steps`` training steps on the global ``model`` / ``optimizer``.

    Each step does the following:
    - sets the scheduled learning rate;
    - samples a batch from the global ``dataset``;
    - runs forward and backward;
    - clips gradients to total norm 1.0;
    - applies one optimizer step.

    It reads module-level globals defined in earlier cells: model, optimizer,
    dataset, batch_size, context_length, device and the LR-schedule constants.

    Args:
        num_steps: Number of optimizer steps. The schedule index ``it``
            restarts at 0 on every call.
        log_every: Print loss/lr on steps where ``it % log_every == 0``
            (must be non-zero).
    """
    # The inference cell calls model.eval(). Restore training mode so that
    # training, benchmarking and profiling are not silently run in eval mode.
    model.train()
    for it in range(num_steps):
        print(f'{it}/{num_steps}')
        # 1. Update the learning rate (every step)
        lr = get_cosine_lr(
            it=it,
            max_learning_rate=max_lr,
            min_learning_rate=min_lr,
            warmup_iters=warmup_iters,
            cosine_cycle_iters=cosine_cycle_iters,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # 2. Sample a batch
        x, y = get_batch(
            dataset=dataset,
            batch_size=batch_size,
            context_length=context_length,
            device=device,
        )

        # 3. Forward
        logits = model(x)

        # 4. Loss
        loss = lm_loss(logits, y)

        # 5. Backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Gradient clipping (optional, strongly recommended)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # 6. Parameter update
        optimizer.step()

        # 7. Logging. This must be inside the loop so it fires every
        # `log_every` steps; after the loop it would run once and fail when
        # num_steps == 0. loss.item() forces a device sync, hence the throttling.
        if it % log_every == 0:
            print(
                f"step {it:6d} | "
                f"loss {loss.item():.4f} | "
                f"lr {lr:.2e}"
            )


In [21]:
# Smoke test: a single training step.
run_model(num_steps=1)

0/1
step      0 | loss 10.5185 | lr 0.00e+00


In [28]:
# Time one training step (after one untimed warm-up step).
benchmark(
    description='transformer模型的基准测试',
    run=run_model,
    num_warmups=1,
    num_trials=1,
)

0/1
step      0 | loss 10.5236 | lr 0.00e+00
现在真正计时!
0/1
step      0 | loss 10.5171 | lr 0.00e+00


3180.018663406372

## 推理：

In [5]:
# Inference demo. Switches the model to eval mode; nothing in this cell
# switches it back to train mode afterwards.
model.eval()

# Disable autograd tracking: generation needs no gradients.
with torch.no_grad():
    # Draw one window (batch_size=1, context_length=1) with get_batch to use as the prompt.
    x, _ = get_batch(
        dataset=dataset,
        batch_size=1,
        context_length= 1,
        device=device,
    )

    # The whole batch is the prompt. Given get_batch's (batch_size, context_length)
    # output seen earlier, this is presumably shape (1, 1).
    prompt = x
    # Generate up to 5 new tokens with the model's own generate method
    # (temperature 1.0, top-k 50; eos_token_id=None presumably disables early stopping).
    generated = model.generate(
        x=prompt,
        max_new_tokens=5,
        temperature=1.0,
        top_k=50,
        eos_token_id=None,
    )

    # The recorded output shows torch.Size([1, 5]), i.e. apparently only the new tokens
    # without the prompt. TODO confirm against BasicsTransformerLM.generate.
    print(generated.shape)
    print(generated)


torch.Size([1, 5])
tensor([[ 6338, 39429, 49762,  2890, 38374]], device='cuda:0')


### 性能分析工具

In [33]:
from torch.profiler import ProfilerActivity
import os

In [34]:
def profile(description: str, run: Callable, num_warmups: int = 1, with_stack: bool = False):
    """Profile one call of ``run`` with torch.profiler and return a summary table.

    Args:
        description: Label for this profile. When ``with_stack`` is True it is
            also used in the exported file name ``var/stacks_{description}.txt``
            (path separators replaced by ``_``).
        run: Zero-argument callable to profile.
        num_warmups: Unprofiled calls made first, so one-off costs do not
            dominate the trace.
        with_stack: Record Python stack traces and export them (for
            visualization) under ``var/``.

    Returns:
        The ``key_averages()`` table as a string with the top 10 ops, sorted
        by total CUDA time when CUDA is available, otherwise by total CPU time.
    """
    cuda = torch.cuda.is_available()

    # Warm-up
    for _ in range(num_warmups):
        run()
    if cuda:
        torch.cuda.synchronize()  # wait for warm-up kernels to finish

    # Only request CUDA tracing when a GPU exists. On CPU-only machines the
    # CUDA columns would be empty and sorting by them meaningless.
    activities = [ProfilerActivity.CPU]
    if cuda:
        activities.append(ProfilerActivity.CUDA)

    with torch.profiler.profile(
            activities=activities,
            # Record Python stacks so they can be exported for visualization.
            with_stack=with_stack,
            # Needed to export stack traces for visualization.
            experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
        run()
        if cuda:
            torch.cuda.synchronize()  # make sure all queued kernels are captured

    table = prof.key_averages().table(
        sort_by="cuda_time_total" if cuda else "cpu_time_total",
        max_name_column_width=80,
        row_limit=10,
    )

    # Write the stack-trace export used for flame-graph style visualization.
    if with_stack:
        os.makedirs("var", exist_ok=True)
        # A path separator in the label would point into a non-existent directory.
        safe_name = description.replace(os.sep, "_").replace("/", "_")
        text_path = os.path.join("var", f"stacks_{safe_name}.txt")
        prof.export_stacks(text_path, "self_cuda_time_total" if cuda else "self_cpu_time_total")
    return table

In [None]:
# Profile one training step with stack export enabled. The table is kept in
# `tabel`, the name the following cell prints.
tabel = profile(
    description='transformer',
    run=run_model,
    num_warmups=1,
    with_stack=True,
)

0/1
step      0 | loss 10.5191 | lr 0.00e+00
0/1
step      0 | loss 10.5105 | lr 0.00e+00


In [None]:
# Print the profiler summary table returned by the profile(...) call stored in `tabel`.
print(tabel)

------------------------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                               aten::bmm         1.98%       7.168ms         1.98%       7.168ms      43.440us     161.452ms        33.52%     161.452ms     978.497us           165  
                       autograd::engine::evaluate_function: BmmBackward0         0.36%       1.322ms         7.39%      26.801ms     487.29

# Flash Attention 