Optimize reduction + amax fusion
ghstack-source-id: f9abe7a11dafa5dd3095cd2d46029c029271424d
Pull Request resolved: #111122
ipiszy committed Oct 16, 2023
1 parent f429757 commit 63c8990
Showing 4 changed files with 403 additions and 26 deletions.
206 changes: 203 additions & 3 deletions test/inductor/test_fp8.py
@@ -1,9 +1,13 @@
# Owner(s): ["module: inductor"]

import functools
import unittest
from typing import Tuple

import torch
from torch import Tensor
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import utils
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -15,6 +19,25 @@
torch.set_float32_matmul_precision("high")


# define the e4m3/e5m2 constants
E4M3_MAX_POS = 448.0
E5M2_MAX_POS = 57344.0


def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
# A common case where we want to saturate is when the history of a
# tensor has a maximum value of `amax1`, and the current amax value
# is `amax2`, where `amax1 < amax2`. This is common when using delayed
# scaling.
if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
else:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
return x.to(float8_dtype)


@instantiate_parametrized_tests
class TestFP8Types(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@@ -59,16 +82,16 @@ def fp8_matmul_unwrapped(x):
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
def test_valid_cast(self, dtype: torch.dtype):
@parametrize("shape", ((15, 3, 13), (4, 2048, 4096)))
def test_valid_cast(self, dtype: torch.dtype, shape: Tuple[int]):
def fp8_cast(x):
y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
return y0, y1

compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

x_shape = (16, 16, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype)
x = torch.rand(*shape, device="cuda", dtype=dtype)
y0_fp8, y1_fp8 = compiled_fp8_cast(x)

torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
@@ -98,6 +121,183 @@ def fp8_cast(x, dtype):
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
y = compiled_fp8_cast(x, torch.float8_e4m3fn)

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ((16, 16, 16), (4, 2048, 4096)))
def test_to_fp8_saturated(
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: Tuple[int]
):
def fp8_saturated(x, dtype):
return _to_fp8_saturated(x, dtype)

compiled_fp8_cast = torch.compile(
fp8_saturated, backend="inductor", dynamic=True
)
x = torch.rand(*shape, device="cuda", dtype=src_dtype)
y_compiled = compiled_fp8_cast(x, dst_dtype)
y = fp8_saturated(x, dst_dtype)

torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize(
"shape", ((1, 1, 15), (1, 10, 15), (1, 10, 512), (1, 10, 4096), (4, 2048, 4096))
)
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: Tuple[int]):
batch_size, sequence_length, hidden_size = shape

def amax_fp8(x: Tensor, scale: Tensor):
y = torch.amax(torch.abs(x))
y_scaled = y.to(dtype=torch.float) * scale
bits_fp8 = _to_fp8_saturated(y_scaled, float8_dtype)
return bits_fp8

compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

y_compiled = compiled_amax_fp8_quant(x, scale)
y = amax_fp8(x, scale)

torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize(
"shape", ((1, 1, 15), (1, 10, 15), (1, 10, 512), (1, 10, 4096), (4, 2048, 4096))
)
def test_amax_along_with_fp8_quant(
self, float8_dtype: torch.dtype, shape: Tuple[int]
):
batch_size, sequence_length, hidden_size = shape

def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
amax_buffer.fill_(torch.amax(torch.abs(x)))
x_scaled = x.to(dtype=torch.float) * scale
bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
return bits_fp8

compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(1.0, device="cuda", dtype=torch.float)

amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
y = amax_fp8(x, scale, amax_buffer)

torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
torch.testing.assert_close(
amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
)

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize(
"shape", ((1, 1, 15), (1, 10, 15), (1, 10, 512), (1, 10, 4096), (4, 2048, 4096))
)
def test_layernorm_fp8_quant(self, float8_dtype: torch.dtype, shape: Tuple[int]):
batch_size, sequence_length, hidden_size = shape

def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
x = torch.nn.functional.layer_norm(
x.to(dtype=torch.float),
[hidden_size],
weight=None,
bias=None,
eps=1e-05,
)
amax_buffer.fill_(torch.amax(torch.abs(x)))
x_scaled = x * scale
bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
return bits_fp8

compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
y = ln_fp8(x, scale, amax_buffer)

torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
torch.testing.assert_close(
amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
)

@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ((4, 2048, 4096),))
def test_layernorm_fp8_quant_benchmark(
self,
float8_dtype: torch.dtype,
shape: Tuple[int],
):
batch_size, sequence_length, hidden_size = shape

def ln(x: Tensor):
x = torch.nn.functional.layer_norm(
x.to(dtype=torch.float),
[hidden_size],
weight=None,
bias=None,
eps=1e-05,
)
return x

def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
x = torch.nn.functional.layer_norm(
x.to(dtype=torch.float),
[hidden_size],
weight=None,
bias=None,
eps=1e-05,
)
amax_buffer.fill_(torch.amax(torch.abs(x)))
x_scaled = x * scale
bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
return bits_fp8

compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
_ = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
compiled_latency = utils.do_bench_using_profiling(
functools.partial(compiled_ln_fp8_quant, x, scale, amax_buffer_compiled)
)
eager_latency = utils.do_bench_using_profiling(
functools.partial(ln_fp8, x, scale, amax_buffer)
)

compiled_ln = torch.compile(ln, backend="inductor")
_ = compiled_ln(x)
ln_latency = utils.do_bench_using_profiling(functools.partial(compiled_ln, x))

print(
f"Config: {float8_dtype=}, {shape=}. "
f"Benchmark results: Inductor: {compiled_latency}ms, Eager: {eager_latency}ms, "
f"LN only Inductor: {ln_latency}ms."
)


if __name__ == "__main__":
if HAS_CUDA:
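
The tests above exercise a common pattern: an amax reduction over the input feeds a pointwise scale-and-saturated-cast to FP8, which is exactly the reduction + pointwise combination this commit teaches Inductor to fuse. Below is a minimal standalone sketch of that pattern; the quantize_e4m3 helper and its scale formula are illustrative assumptions, not code from this PR (FP8 dtypes require a recent PyTorch build, and the compiled path is only expected to work on H100+ as the tests note).

from typing import Tuple

import torch

E4M3_MAX_POS = 448.0  # largest finite float8_e4m3fn value, as in the test file above


def quantize_e4m3(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reduction: amax of |x| over the whole tensor.
    amax = torch.amax(torch.abs(x))
    # Pointwise: scale, clamp to the representable range, then cast to FP8.
    # The scale formula below is an illustrative choice, not part of this PR.
    scale = E4M3_MAX_POS / torch.clamp(amax.to(torch.float), min=1e-12)
    x_scaled = x.to(torch.float) * scale
    x_fp8 = x_scaled.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS).to(torch.float8_e4m3fn)
    return x_fp8, amax


if __name__ == "__main__":
    x = torch.randn(4, 2048, 4096, device="cuda", dtype=torch.half)
    eager_fp8, eager_amax = quantize_e4m3(x)
    compiled = torch.compile(quantize_e4m3, backend="inductor")
    compiled_fp8, compiled_amax = compiled(x)
    torch.testing.assert_close(compiled_fp8.half(), eager_fp8.half(), rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(compiled_amax, eager_amax)
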
45 changes: 45 additions & 0 deletions torch/_inductor/dependencies.py
@@ -359,5 +359,50 @@ def extract_read_writes(
)


def extract_input_node_reduction_ranges( # noqa: F722
input_node: ".ir.TensorBox", # type: ignore[valid-type] # noqa: F722
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
"""
Returns the size and reduction size of all inputs, provided that the sizes and reduction_sizes (where they exist) are all the same.
It's possible that a node has multiple inputs, some of which are Reduction nodes and others Pointwise nodes.
In this case, the reduction_sizes of the Reduction nodes need to be the same.
Otherwise, (None, None) is returned.
"""

from .ir import ComputedBuffer, Loops

if not isinstance(input_node.data.data, Loops):
# Input node has already been realized. Return its size and reduction_size.
if hasattr(input_node, "get_size") and hasattr(
input_node, "get_reduction_size"
):
return (input_node.get_size(), input_node.get_reduction_size())
else:
return (None, None)

# There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
# Is there a way to check whether there are permutations in between?
reads = input_node.get_reads()
reduction_size = None
size = None
for read in reads:
if not isinstance(read, MemoryDep):
continue
buffer = V.graph.get_buffer(read.name)
if buffer is None:
continue
if isinstance(buffer, ComputedBuffer) and len(buffer.get_reduction_size()) > 0:
if reduction_size is None:
reduction_size = buffer.get_reduction_size()
size = buffer.get_size()
elif (
reduction_size != buffer.get_reduction_size()
or size != buffer.get_size()
):
return (None, None)
return (size, reduction_size)


def canonicalization_prefix():
return "c"
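
For intuition about the consistency rule that extract_input_node_reduction_ranges enforces, here is a minimal standalone sketch operating on plain (size, reduction_size) pairs instead of Inductor IR nodes; the ranges_if_consistent name and the list-of-tuples input format are assumptions made for illustration only.

from typing import List, Optional, Tuple

Sizes = List[int]


def ranges_if_consistent(
    inputs: List[Tuple[Sizes, Sizes]]
) -> Tuple[Optional[Sizes], Optional[Sizes]]:
    # Each input is a (size, reduction_size) pair; only inputs that actually
    # reduce (non-empty reduction_size) participate in the consistency check.
    size: Optional[Sizes] = None
    reduction_size: Optional[Sizes] = None
    for in_size, in_reduction_size in inputs:
        if len(in_reduction_size) == 0:
            continue  # pointwise input: imposes no constraint on reduction ranges
        if reduction_size is None:
            size, reduction_size = in_size, in_reduction_size
        elif in_size != size or in_reduction_size != reduction_size:
            return (None, None)  # mismatch: caller should not fuse
    return (size, reduction_size)


# Two reductions over the same ranges agree; a mismatched third yields (None, None).
assert ranges_if_consistent([([4, 2048], [4096]), ([4, 2048], [4096])]) == (
    [4, 2048],
    [4096],
)
assert ranges_if_consistent([([4, 2048], [4096]), ([4, 2048], [2048])]) == (None, None)
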
