
[FP8] performance degradation in speed and memory without compile #685

Closed · leeeizhang opened this issue Aug 15, 2024 · 2 comments
leeeizhang commented Aug 15, 2024

The FP8 FFN forward pass degrades in both speed and peak GPU memory when the module is not wrapped in torch.compile.

Variant                  | Torch FP16        | AO FP8 (compile=False) | AO FP8 (compile=True)
-------------------------|-------------------|------------------------|----------------------
bs=32, seq=512, dim=512  | 0.92 ms, 308 MB   | 3.14 ms, 594 MB        | 0.95 ms, 339 MB
bs=32, seq=512, dim=1024 | 3.14 ms, 664 MB   | 7.17 ms, 1.2 GB        | 2.61 ms, 724 MB
bs=32, seq=512, dim=2048 | 11.38 ms, 1.53 GB | 17.84 ms, 2.6 GB       | 7.84 ms, 1.6 GB
bs=32, seq=512, dim=4096 | 43.16 ms, 3.9 GB  | 49.25 ms, 6.1 GB       | 26.20 ms, 4.1 GB

Trace Logs (torch.profiler)

[Profiler trace screenshots: compile=True (2.1ms) vs compile=False (6.3ms)]
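(The trace screenshots are not reproduced here. For reference, a minimal sketch of how a comparable trace can be collected with torch.profiler; torch_fp8_ffn and x refer to the objects defined in the reproduction script below, and the output file name is arbitrary.)

import torch
from torch.profiler import profile, ProfilerActivity

# Trace the FP8 FFN forward pass (torch_fp8_ffn and x come from the script below).
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        _ = torch_fp8_ffn(x)
    torch.cuda.synchronize()

# Summarize the hottest CUDA kernels and export a timeline for chrome://tracing.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("fp8_ffn_trace.json")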

Testbed

  • Torch: 2.5.0.dev20240814+cu121
  • TorchAO: 2024.8.15+cu121
  • CUDA Version: 12.1 (NVIDIA L20, SM89)

Code to Reproduce the Issue

"""
usage: $ python3 test.py --bs 32 --seq 512 --dim 1024 --compile 0
"""
import time
import argparse

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training


class TorchFFN(nn.Module):
    def __init__(self, in_feature, hidden_feature, bias=True):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
        self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x


if __name__ == "__main__":
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int, required=False, default=32)
    parser.add_argument('--seq', type=int, required=False, default=512)
    parser.add_argument('--dim', type=int, required=False, default=1024)
    parser.add_argument('--compile', type=int, required=False, default=0)
    args = parser.parse_args()

    # Test fp8 linear
    BS, SQ, DIM = args.bs, args.seq, args.dim
    x = torch.randn((BS, SQ, DIM), device="cuda")

    torch_fp16_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")

    torch_fp8_ffn = TorchFFN(DIM, 4 * DIM).to("cuda")
    torch_fp8_ffn.load_state_dict(torch_fp16_ffn.state_dict())  # Align weights

    convert_to_float8_training(torch_fp8_ffn)
    if args.compile > 0:
        torch_fp8_ffn = torch.compile(torch_fp8_ffn)

    with torch.inference_mode():
        # Warmup
        for _ in range(10):
            _ = torch_fp8_ffn(x)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                _ = torch_fp16_ffn(x)

        # Test torch fp8 speed
        s = time.time()
        for _ in range(1000):
            torch_fp8_y = torch_fp8_ffn(x)
            torch.cuda.synchronize()
        e = time.time()
        # e - s is total seconds over 1000 iterations, i.e. average milliseconds per iteration.
        print(f"torch fp8: {e - s:.2f} ms / iteration")

        # Test torch fp16 speed
        s = time.time()
        for _ in range(1000):
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch_fp16_y = torch_fp16_ffn(x)
            torch.cuda.synchronize()
        e = time.time()
        # e - s is total seconds over 1000 iterations, i.e. average milliseconds per iteration.
        print(f"torch fp16: {e - s:.2f} ms / iteration")

        # Profile memory
        torch.cuda.reset_peak_memory_stats("cuda:0")
        torch.cuda.empty_cache()
        _ = torch_fp8_ffn(x)
        peak_memory = torch.cuda.max_memory_allocated("cuda:0")
        print(f"Torch FP8 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")

        torch.cuda.reset_peak_memory_stats("cuda:0")
        torch.cuda.empty_cache()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch_fp16_y = torch_fp16_ffn(x)
        peak_memory = torch.cuda.max_memory_allocated("cuda:0")
        print(f"Torch FP16 Peak memory usage: {peak_memory / 1024 ** 2:.2f} MB")

    print(f"[torch-fp8 v.s. torch-fp16] mse loss: {nn.functional.mse_loss(torch_fp16_y, torch_fp8_y)}")
@msaroufim
Member

I believe this is expected: in eager mode you'd be dispatching separate kernels for the quant, matmul, and dequant, whereas with compile you can do things like fuse the matmul and dequant into a single kernel. @vkuzo might have a longer-form answer that applies to FP8.
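For illustration, a simplified per-tensor-scaling sketch of what the eager path does (plain PyTorch, not torchao's actual kernels; fp8_matmul_eager is a hypothetical helper, and each step is roughly its own kernel launch):

import torch

def fp8_matmul_eager(a, b):
    # a: (M, K) activations, b: (N, K) weight; dims should be multiples of 16 for the FP8 GEMM.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    a_scale = (a.abs().amax() / fp8_max).float()   # amax + scale computation
    b_scale = (b.abs().amax() / fp8_max).float()
    a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)  # quantize (cast) kernels
    b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)
    # FP8 GEMM; the scales are applied back to the output.
    return torch._scaled_mm(a_fp8, b_fp8.t(), scale_a=a_scale, scale_b=b_scale,
                            out_dtype=torch.bfloat16)

Under torch.compile, the amax/scale/cast steps can be fused into far fewer kernels, which is where the eager-mode gap in the table above comes from.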

@drisspg
Contributor

drisspg commented Aug 21, 2024

Yes, this is expected; today we rely entirely on torch.compile to generate fused casting kernels.
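In other words, the intended usage is to always wrap the converted module in torch.compile, as the reproduction script above does when --compile 1 is passed (sketch only, reusing TorchFFN from that script):

import torch
from torchao.float8 import convert_to_float8_training

model = TorchFFN(1024, 4 * 1024).to("cuda")
convert_to_float8_training(model)  # swaps nn.Linear modules for Float8Linear, in place
model = torch.compile(model)       # required so Inductor can fuse the casting/scaling kernels

Running with TORCH_LOGS="output_code" prints the Triton kernels Inductor generates, which shows the fused casts.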
