-
Notifications
You must be signed in to change notification settings - Fork 349
Description
Hi Torch Team,
I am currently experimenting with native PyTorch float8 training (torchao.float8) and comparing it to Transformer Engine (TE). I am using the delayed scaling recipe on GPT 1.5B at batch=12, seq=1024, on a 700 W H100 SXM 80 GB SKU.
I see that TE fp8 provides a slight performance increase over bf16 autocast, but unfortunately torchao.float8 is almost 2x slower. I tried to improve performance by enabling fp8 and bf16 autocast at the same time, but ran into the error `ValueError: All layers must have the same last seen input_dtype, got {torch.float32, torch.bfloat16}`. Enabling fp8 together with bf16 autocast is something TE does, but I am not sure whether it is needed for torchao.
Can you provide some guidance on how to improve performance on torchao.float8?
Thanks!
BF16 Autocast: 493.17 TFLOP/s
FP8 TE: 501.2 TFLOP/s
torchao.float8: 240.67 TFLOP/s
Repro Script
import torch
import torch.nn as nn
from torchao.float8 import (
convert_to_float8_training,
sync_float8_amax_and_scale_history,
Float8LinearConfig,
ScalingType,
CastConfig,
)
import torch.nn.functional as F
import fire
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Tensor-name suffixes give the layout: B=batch, T=sequence, E=embedding
    width, H=heads, D=head width (E = H * D).

    Args:
        d_embd: Embedding width E.
        n_heads: Number of attention heads H; must evenly divide ``d_embd``.
        **kwargs: Ignored; accepted so a shared config dict can be splatted in.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``d_embd``.
    """

    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        # Without this check, d_embd // n_heads silently rounds down and
        # forward() splits into a different number of heads than requested
        # (e.g. d_embd=12, n_heads=5 -> six heads of width 2).
        if n_heads <= 0 or d_embd % n_heads:
            raise ValueError(
                f"d_embd ({d_embd}) must be divisible by n_heads ({n_heads})"
            )
        self.d_head = d_embd // n_heads  # D
        # One projection produces q, k and v stacked along the last dim.
        self.attn_proj = nn.Linear(d_embd, 3 * d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)

    def forward(self, x_BTE):
        # Project once, then cut the last dim into equal q, k, v chunks of width E.
        q_BTE, k_BTE, v_BTE = self.attn_proj(x_BTE).split(x_BTE.size(-1), dim=-1)

        def split_heads(z_BTE):
            # (B, T, E) -> (B, T, H, D) -> (B, H, T, D), the layout SDPA takes.
            return z_BTE.unflatten(-1, (-1, self.d_head)).transpose(1, 2)

        o_BHTD = F.scaled_dot_product_attention(
            split_heads(q_BTE),
            split_heads(k_BTE),
            split_heads(v_BTE),
            dropout_p=0.0,
            is_causal=True,
        )
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, E)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        return self.out_proj(o_BTE)
class GPTBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MLP.

    Each sub-layer reads a LayerNorm of its input and its output is added back
    as a residual. Extra keyword arguments (e.g. ``n_heads``) are forwarded to
    ``CausalSelfAttention``.
    """

    def __init__(self, d_embd, **kwargs):
        super().__init__()
        hidden_width = 4 * d_embd  # MLP expansion width
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_embd),
        )

    def forward(self, x_BTE):
        attn_out_BTE = self.attn(self.attn_norm(x_BTE))
        h_BTE = x_BTE + attn_out_BTE
        ffn_out_BTE = self.ffn(self.ffn_norm(h_BTE))
        return h_BTE + ffn_out_BTE
class GPT(nn.Module):
    """GPT-style decoder-only language model.

    Learned token and position embeddings feed a stack of ``GPTBlock`` layers,
    followed by a final LayerNorm. Output logits reuse the token-embedding
    matrix (weight tying). Extra keyword arguments are forwarded to every block.
    """

    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        blocks = [GPTBlock(d_embd, **kwargs) for _ in range(n_layers)]
        self.tsfmr_blks = nn.ModuleList(blocks)
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        seq_len = idx_BT.size(1)
        pos_T = torch.arange(seq_len, dtype=torch.int64, device=idx_BT.device)
        # Position embeddings broadcast over the batch dimension.
        h_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)
        for block in self.tsfmr_blks:
            h_BTE = block(h_BTE)
        h_BTE = self.out_norm(h_BTE)
        # Weight tying: project onto the token-embedding matrix.
        # NOTE(review): this is a plain matmul, not an nn.Linear, so it is
        # presumably not swapped by convert_to_float8_training and runs in the
        # model's own dtype (fp32 unless under autocast) -- likely a sizable
        # share of the fp8-path step time; confirm with a profile.
        return h_BTE @ self.tok_embd.weight.T
# Delayed-scaling float8 recipe, mirroring the delayed-scaling setup used on the
# Transformer Engine side of this comparison. The cast configs for input,
# weight and grad_output (cast_config_input / cast_config_weight /
# cast_config_grad_output) all use ScalingType.DELAYED. The amax/scale history
# is refreshed once per step in main() via sync_float8_amax_and_scale_history().
#
# Float8LinearConfig's enable_amax_init and enable_pre_and_post_forward are left
# at their defaults: setting them to False is only needed for the
# autocast + compile + FSDP + float8 delayed combination.
config = Float8LinearConfig(
    **{
        f"cast_config_{tensor_name}": CastConfig(scaling_type=ScalingType.DELAYED)
        for tensor_name in ("input", "weight", "grad_output")
    }
)
def main(enable_fp8=True, batch_size=12, seq_len=1024, n_steps=100):
    """Benchmark GPT-1.5B training throughput with bf16 autocast or torchao float8.

    Args:
        enable_fp8: If True, convert the model with ``convert_to_float8_training``
            (using the module-level delayed-scaling ``config``) and run without
            autocast. Otherwise run forward + loss under bf16 autocast.
        batch_size: Sequences per step (B).
        seq_len: Tokens per sequence (T); must be in ``[1, max_seq_len]``.
        n_steps: Number of timed training steps.

    Prints achieved TFLOP/s for every step, using the standard
    ``6 * N * B * T`` estimate, where N is the parameter count.

    Raises:
        ValueError: If ``batch_size`` or ``seq_len`` is out of range.
    """
    import contextlib

    torch.manual_seed(3985)
    torch.cuda.set_device(0)
    device = torch.device("cuda:0")

    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt",  # extra key; passed down via **kwargs and unused
    }
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not 0 < seq_len <= cfg_json["max_seq_len"]:
        raise ValueError(
            f"seq_len must be in [1, {cfg_json['max_seq_len']}], got {seq_len}"
        )
    vocab_size = cfg_json["vocab_size"]

    model = GPT(**cfg_json).to(device)
    N = sum(p.numel() for p in model.parameters())  # get param count
    # Tie the FLOP count to the batch actually fed below. Hard-coding a batch of
    # 16 while feeding 12 inflates the reported TFLOP/s by 16/12.
    flops_per_iter = 6 * N * batch_size * seq_len
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    if enable_fp8:
        convert_to_float8_training(model, config=config)
    # Compilation happens lazily on the first call, so the first step(s) report
    # low throughput; read steady-state numbers from later steps.
    model = torch.compile(model)

    # NOTE(review): the fp8 path runs without autocast on an fp32 model. Only the
    # converted nn.Linear layers use float8; LayerNorm, SDPA and the tied logits
    # matmul stay in fp32, while the bf16 baseline runs them in bf16. That is a
    # likely contributor to the throughput gap; confirm with a profile.
    for step_idx in range(n_steps):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(
            vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device
        )
        label_BT = torch.randint(
            vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device
        )
        start.record()
        amp_ctx = (
            contextlib.nullcontext()
            if enable_fp8
            else torch.amp.autocast("cuda", dtype=torch.bfloat16)
        )
        with amp_ctx:
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()
        if enable_fp8:
            # Delayed scaling: refresh amax history and scales once per step.
            sync_float8_amax_and_scale_history(model)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()
        torch.cuda.synchronize()  # events must complete before elapsed_time()
        t = start.elapsed_time(end) / 1e3  # elapsed_time() returns milliseconds
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")


if __name__ == "__main__":
    fire.Fire(main)
Dependencies
$ pip list | grep torch
pytorch-triton 3.1.0+cf34004b8a
torch 2.6.0.dev20241030+cu124
torch-tb-profiler 0.4.3
torchao 0.7.0.dev20241112+cu121