Auto mixed precision    
Core idea: BF16 keeps 8 exponent bits, the same as FP32.

https://docs.qq.com/doc/DSUtuWFREUEFCbHZQ    
2.2.2

---

This guide covers modern Automatic Mixed Precision (AMP), from the fundamentals of data types to industry practice such as Megatron-LM and FP8.

---

## 🚀 Introduction to Modern AMP

### What is the Core Idea?

The core idea of **AMP** is to use **different numerical precisions** for different operations during training to get the best of both worlds:

1. **Speed & Memory Efficiency:** Use lower precision (16-bit) for computationally heavy operations such as matrix multiplications and convolutions.
2. **Numerical Stability:** Use full precision (32-bit) for sensitive operations such as reductions (sums, softmax, loss computation) and for storing the "Master Weights."

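By doing this, you can often **speed up training substantially** and **cut activation memory roughly in half**, typically without losing model accuracy. Total memory does not halve, because the FP32 master weights and optimizer state are still kept.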
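As a rough illustration of this split, the sketch below runs a matrix multiplication and a loss inside `torch.autocast` and prints the resulting dtypes. It uses the CPU backend only so that it runs without a GPU, and the tensor names are made up for the example. Which op lands in which precision follows PyTorch's autocast op lists and can differ between backends and versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy tensors, created in FP32 like ordinary model weights.
weights = torch.randn(8, 8)
inputs = torch.randn(4, 8)
targets = torch.zeros(4, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    activations = inputs @ weights            # matmul: autocast runs it in BF16
    loss = F.mse_loss(activations, targets)   # loss: computed in FP32

print("weights    :", weights.dtype)       # the stored weights are left untouched
print("activations:", activations.dtype)
print("loss       :", loss.dtype)
```

Note that autocast never changes the dtype of the stored parameters; it only casts the inputs of selected ops on the fly.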

---

### Comparison: FP16 vs. BF16

Modern deep learning has shifted toward **BF16** (Brain Float 16) because it is much more robust for training than the older **FP16**.

| Feature | FP16 (Standard Half) | BF16 (Brain Float) |
| --- | --- | --- |
| **Exponent Bits** | 5 bits (Small range) | 8 bits (**Same as FP32**) |
| **Mantissa Bits** | 10 bits (Higher precision) | 7 bits (Lower precision) |
| **Max Value** | ~65,504 | ~3.4 × 10^38 |
| **Gradient Scaling** | **Mandatory** (Avoids underflow) | **Optional/Unnecessary** |
| **Hardware** | Most GPUs (Volta+) | Modern GPUs (Ampere A100/30-series+) |
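The range and precision figures in the table follow directly from the bit counts. The sketch below derives them for an IEEE-style binary format; `float_format_stats` is a hypothetical helper written for this note, not a library function.

```python
def float_format_stats(exponent_bits, mantissa_bits):
    """Return (max_value, min_normal, epsilon) for an IEEE-style binary format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_value = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -mantissa_bits  # gap between 1.0 and the next representable value
    return max_value, min_normal, epsilon

# (exponent bits, mantissa bits) as listed in the table, plus FP32 for reference.
FORMATS = {"FP16": (5, 10), "BF16": (8, 7), "FP32": (8, 23)}
STATS = {name: float_format_stats(*bits) for name, bits in FORMATS.items()}

for name, (max_value, min_normal, epsilon) in STATS.items():
    print(f"{name}: max={max_value:.4g}  min_normal={min_normal:.4g}  eps={epsilon:.4g}")
```

BF16 and FP32 share the same smallest normal value and nearly the same maximum, while FP16 has a finer epsilon but a far narrower range.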

#### When to use BF16?

* **Always** if your hardware supports it (NVIDIA Ampere architecture or newer, like A100, H100, RTX 30/40 series).
* When training **Large Language Models (LLMs)**, as they are prone to "spiky" gradients that cause FP16 to overflow/diverge.

#### Does BF16 need adjustments (like Learning Rate)?

* **Learning Rate:** Generally, you **do not** need to change your learning rate when switching from FP32 to BF16. It is much more "plug-and-play" than FP16.
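* **Epsilon (`eps`):** In optimizers like Adam, you might need to increase the epsilon value (e.g., from `1e-8` to `1e-5`) if the optimizer arithmetic or state is kept in BF16. BF16 can represent `1e-8` itself, but with only 7 mantissa bits, adding it to a much larger denominator rounds it away, so it stops having any effect. With plain `torch.amp` autocast the optimizer still runs in FP32, so the default usually works.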
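The epsilon effect can be reproduced without a GPU. The sketch below emulates BF16 rounding in pure Python by truncating an FP32 bit pattern to its top 16 bits with round-to-nearest-even. `to_bf16` is a hypothetical helper for illustration, not PyTorch's implementation, and the denominator value `1e-3` is an arbitrary stand-in for Adam's square-root term.

```python
import struct

def to_bf16(x):
    """Round a Python float to the nearest BF16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # rounding bias
    bits &= 0xFFFF0000                                   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

denominator = 1e-3  # assumed magnitude of sqrt(v) in Adam

eps_alone = to_bf16(1e-8)                          # representable on its own
with_small_eps = to_bf16(denominator + 1e-8)       # eps is rounded away
with_large_eps = to_bf16(denominator + 1e-5)       # eps survives the addition
baseline = to_bf16(denominator)

print("1e-8 in BF16 is nonzero:", eps_alone > 0)
print("eps=1e-8 changes the denominator:", with_small_eps != baseline)
print("eps=1e-5 changes the denominator:", with_large_eps != baseline)
```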

---

## 🛠 How AMP Influences Model Training

1. **Memory Footprint:** You can fit **larger models** or **larger batch sizes** on the same GPU.
2. **Throughput:** Training is significantly faster because modern Tensor Cores are optimized for 16-bit math.
3. **Stability:**
    * **FP16** requires a `GradScaler` to prevent gradients from becoming zero (underflow).
    * **BF16** behaves almost exactly like FP32 training, making it highly stable.
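To see why FP16 needs loss scaling, the sketch below rounds a tiny gradient to FP16 using the standard library's half-precision `struct` format. The gradient value and the fixed scale of `2**16` are assumptions chosen for the demonstration; a real `GradScaler` adjusts its scale dynamically.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 value via the struct 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

tiny_grad = 1e-8          # hypothetical gradient, below FP16's smallest subnormal
loss_scale = 2.0 ** 16    # hypothetical fixed loss scale

unscaled = to_fp16(tiny_grad)             # underflows to zero
scaled = to_fp16(tiny_grad * loss_scale)  # survives in FP16
recovered = scaled / loss_scale           # unscaled again in higher precision

print("unscaled FP16 gradient:", unscaled)
print("scaled FP16 gradient  :", scaled)
print("recovered gradient    :", recovered)
```

Scaling the loss before the backward pass shifts gradients into FP16's representable range, and dividing by the same factor before the optimizer step restores their true magnitude.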



---

## 🏢 Advanced Industry Methods

### 1. AMP in Megatron (NVIDIA)

**Megatron-LM** is a framework for training massive models. Its AMP implementation is highly optimized:

* **Master Weights:** It maintains a main copy of weights in `FP32` and a "buffer" copy in `FP16/BF16` for the actual forward/backward pass.
* **Distributed Optimizer:** It integrates AMP with data/model parallelism. Gradient synchronization (All-Reduce) can run in lower precision to save bandwidth, or in FP32 when configured for accuracy, while weight updates happen in higher precision.
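The reason for keeping higher-precision master weights can be shown with a toy experiment. This is a minimal sketch and not Megatron-LM code: `to_bf16` is a hypothetical pure-Python emulation of BF16 rounding, a Python float stands in for the FP32 master weight, and the constant update of `1e-3` is an assumed value.

```python
import struct

def to_bf16(x):
    """Round a Python float to the nearest BF16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

update = 1e-3   # hypothetical per-step weight update (learning rate times gradient)
steps = 1000

bf16_weight = 1.0     # weight updated directly in BF16
master_weight = 1.0   # high-precision master weight

for _ in range(steps):
    bf16_weight = to_bf16(bf16_weight + update)  # update is smaller than half a BF16 step at 1.0
    master_weight += update                      # accumulates normally

working_copy = to_bf16(master_weight)  # low-precision copy used for forward/backward

print("updated only in BF16:", bf16_weight)
print("master weight       :", master_weight)
print("BF16 working copy   :", working_copy)
```

The weight updated only in BF16 never moves, because each small update is rounded away, while the master weight accumulates all updates and its BF16 copy reflects the result.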

### 2. The Rise of FP8 (e.g., SGLang, Transformer Engine)

**FP8** (8-bit floating point) is the latest frontier for H100 (Hopper) GPUs.

* **Training:** Libraries like NVIDIA's **Transformer Engine** use FP8 for the matrix multiplications in the forward and backward passes. On H100, peak FP8 Tensor Core throughput is about 2x that of BF16, although end-to-end training speedups are usually smaller.
* **Inference (SGLang):** Frameworks like **SGLang** use FP8 to compress models for serving. Because FP8 has a better dynamic range than INT8, it maintains higher accuracy for LLM inference while drastically reducing the KV cache memory and model size.
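The dynamic range claim can be checked with a few lines of arithmetic. The sketch below compares the ratio of the largest to the smallest positive value for INT8 and for the two common FP8 encodings, E4M3 and E5M2. The text above does not say which FP8 encoding is meant, so treating these two as the relevant ones is an assumption.

```python
# (largest finite value, smallest positive value) per format.
# FP8 figures are for the E4M3 and E5M2 encodings (assumed, see lead-in).
RANGES = {
    "INT8": (127.0, 1.0),
    "FP8 E4M3": (448.0, 2.0 ** -9),
    "FP8 E5M2": (57344.0, 2.0 ** -16),
}

DYNAMIC_RANGE = {name: largest / smallest for name, (largest, smallest) in RANGES.items()}

for name, ratio in DYNAMIC_RANGE.items():
    print(f"{name}: max/min ratio = {ratio:.4g}")
```

INT8 spaces its values uniformly, so it needs a carefully chosen scale per tensor or channel, whereas FP8 spends bits on an exponent and covers several orders of magnitude at the cost of coarser steps for large values.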

---

## 📝 Summary: Which one should you use?

| Scenario | Recommended Method |
| --- | --- |
| **Standard GPU (Older)** | `torch.amp` with `float16` + `GradScaler` |
| **Modern GPU (A100/H100)** | `torch.amp` with `bfloat16` (No scaler needed) |
| **Massive LLM Training** | Megatron-LM or DeepSpeed with `BF16` |
| **Cutting-edge H100 Training** | `FP8` via Transformer Engine |

```python
import torch
import torch.nn as nn

# 1. Set up the device and pick the autocast dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
# On Ampere or newer GPUs (A100, RTX 30xx, ...) bfloat16 is usually the better choice.
# Older GPUs fall back to float16; CPU autocast uses bfloat16.
if device == "cuda" and not torch.cuda.is_bf16_supported():
    use_dtype = torch.float16
else:
    use_dtype = torch.bfloat16

model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# GradScaler is only needed for float16. With bfloat16 it is disabled,
# and every scaler call below becomes a pass-through.
scaler = torch.amp.GradScaler(device, enabled=(use_dtype == torch.float16))  ##### Important

# Dummy data
data = torch.randn(64, 10, device=device)
target = torch.randn(64, 1, device=device)

print(f"Starting training on {device} using {use_dtype}...")

for epoch in range(5):
    optimizer.zero_grad()

    # 2. Autocast context for the forward pass and the loss
    with torch.amp.autocast(device_type=device, dtype=use_dtype):  ##### Important
        output = model(data)
        loss = torch.nn.functional.mse_loss(output, target)

    # 3. Scaled backward pass
    # scaler.scale(loss) multiplies the loss by the current scale factor
    scaler.scale(loss).backward()  ##### Important

    # 4. Scaled step
    # scaler.step() unscales the gradients, skips the step if they contain
    # inf/NaN, and otherwise calls optimizer.step()
    scaler.step(optimizer)  ##### Important

    # 5. Update the scale factor for the next iteration
    scaler.update()  ##### Important

    print(f"Epoch {epoch}: Loss {loss.item():.4f}")

print("Training Complete!")
```

```text
Starting training on cuda using torch.bfloat16...
Epoch 0: Loss 1.2899
Epoch 1: Loss 1.2702
Epoch 2: Loss 1.2508
Epoch 3: Loss 1.2337
Epoch 4: Loss 1.2178
Training Complete!
```
