[float8] FP8 GPT 1.5B Delayed Scaling 2x Slower than BF16 #1297

@functionstackx

Description

Hi Torch Team,

I am currently experimenting with native torch float8 training and comparing it against Transformer Engine using the delayed scaling recipe, on GPT 1.5B at batch=12, seq=1024, on a 700W H100 SXM 80GB SKU.

I see that fp8 with Transformer Engine provides a slight perf improvement over autocast bf16, but unfortunately torchao.float8 is almost 2x slower. I attempted to improve performance by enabling fp8 and bf16 autocast at the same time, but ran into ValueError: All layers must have the same last seen input_dtype, got {torch.float32, torch.bfloat16}. Enabling fp8 together with bf16 autocast is something TE does, but I am not sure whether it is needed for torchao.
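
Roughly the combination I attempted (a minimal sketch using the same pieces as the reprod script below; the exact ordering in my run may have differed):

# convert Linear layers to float8, then run the forward under bf16 autocast --
# this is the combination that raised the ValueError above for me
model = GPT(**cfg_json).to('cuda:0')
convert_to_float8_training(model, config=config)
model = torch.compile(model)

with torch.amp.autocast('cuda', torch.bfloat16):
    logits_BTV = model(input_BT)
    loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
loss.backward()
sync_float8_amax_and_scale_history(model)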

Can you provide some guidance on how to improve performance on torchao.float8?

Thanks!

BF16 Autocast: 493.17 TFLOP/s
FP8 TE: 501.2 TFLOP/s
torchao.float8: 240.67 TFLOP/s 

Reprod Script

import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
import torch.nn.functional as F
import fire

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)
 
    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

# configure delayed scaling
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # enable_amax_init=False,  # only needed for autocast + compile + FSDP +  float8 delayed
    # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP +  float8 delayed
)
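
# NOTE: delayed scaling keeps an amax history per float8-cast tensor (input, weight,
# grad_output); the training loop below calls sync_float8_amax_and_scale_history(model)
# once per step after backward() to update that history and recompute the fp8 scales.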

def main(enable_fp8=True):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)
    
    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }
    model = GPT(**cfg_json).to('cuda:0')
    
    N = sum(p.numel() for p in model.parameters())  # get param count

    flops_per_iter = 6 * N * 12 * 1024  # 6 * params * tokens per step (batch=12, seq=1024)
    
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    if enable_fp8:
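        # swap eligible nn.Linear modules for Float8Linear using the delayed scaling config above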
        convert_to_float8_training(model, config=config)

    model = torch.compile(model)
    
    for step_idx in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')

        start.record()
        if not enable_fp8:
            with torch.amp.autocast('cuda', torch.bfloat16):
                logits_BTV = model(input_BT)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        else:
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()
        if enable_fp8:
            sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()

        torch.cuda.synchronize()
        t = start.elapsed_time(end) / 1e3
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")

if __name__ == "__main__":
    fire.Fire(main)
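
To reproduce (assuming the script is saved as repro.py; fire exposes enable_fp8 as a CLI flag):

$ python repro.py                      # torchao.float8 delayed scaling (default)
$ python repro.py --enable_fp8=False   # bf16 autocast baseline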

Dependencies

$ pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241030+cu124
torch-tb-profiler            0.4.3
torchao                      0.7.0.dev20241112+cu121
