14 changes: 14 additions & 0 deletions benchmarks/run.py
@@ -285,6 +285,11 @@ class RunResult:
"examples.low_mem_dropout",
"low_mem_dropout_tritonbench",
),
"bf16xint16_gemm": (
"tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",
"examples.bf16xint16_gemm",
"bf16xint16_gemm_tritonbench",
),
}


@@ -551,6 +556,15 @@ class RunResult:
"helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
"helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
},
"bf16xint16_gemm": {
"bf16xbf16": "baseline",
"bf16xint16-speedup": "triton_speedup",
"bf16xint16-accuracy": "triton_accuracy",
"torch_compile_bf16xbf16-speedup": "torch_compile_speedup",
"torch_compile_bf16xbf16-accuracy": "torch_compile_accuracy",
"helion_bf16xint16_gemm_tritonbench-speedup": "helion_speedup",
"helion_bf16xint16_gemm_tritonbench-accuracy": "helion_accuracy",
},
}


169 changes: 169 additions & 0 deletions examples/bf16xint16_gemm.py
@@ -0,0 +1,169 @@
"""
BF16 x INT16 GEMM with Helion
============================================================
The kernel performs matrix multiplication where one matrix is in bfloat16 format and the other is in int16 format.
The int16 values are converted to bfloat16 before performing the matrix multiplication.
"""

# %%
from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor

import helion
import helion.language as hl


# %%
@helion.kernel(static_shapes=True)
def _bf16xint16_gemm(x: Tensor, w: Tensor) -> Tensor:
"""
x is bf16, w is int16.
"""
M, K = x.shape
K2, N = w.shape
assert K == K2, f"size mismatch {K} != {K2}"

out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)

for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(K):
x_tile = x[tile_m, tile_k]
w_tile = w[tile_k, tile_n].to(torch.bfloat16)
acc = hl.dot(x_tile, w_tile, acc=acc)
out[tile_m, tile_n] = acc.to(torch.bfloat16)

return out


# %%
@helion.kernel(static_shapes=True)
def _int16xbf16_gemm(x: Tensor, w: Tensor) -> Tensor:
"""
x is int16, w is bf16.
"""
M, K = x.shape
K2, N = w.shape
assert K == K2, f"size mismatch {K} != {K2}"

out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)

for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(K):
x_tile = x[tile_m, tile_k].to(torch.bfloat16)
w_tile = w[tile_k, tile_n]
acc = hl.dot(x_tile, w_tile, acc=acc)
out[tile_m, tile_n] = acc.to(torch.bfloat16)

return out


# %%
def bf16xint16_gemm(x: Tensor, w: Tensor, transpose: bool = False) -> Tensor:
"""
This function dispatches to the appropriate kernel based on the transpose flag.

Args:
x (Tensor): Input tensor.
w (Tensor): Weight tensor.
transpose (bool): If True, assumes x is int16 and w is bf16. Default: False.

Returns:
Tensor: Output tensor in bfloat16 format.
"""
if transpose:
return _int16xbf16_gemm(x, w)
return _bf16xint16_gemm(x, w)
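

# Usage sketch (hypothetical tensor names; mirrors the `check` function below):
#   out = bf16xint16_gemm(x_bf16, w_int16)                  # x bf16 [M, K], w int16 [K, N]
#   out = bf16xint16_gemm(x_int16, w_bf16, transpose=True)  # x int16 [M, K], w bf16 [K, N]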


# %%
def bf16xint16_gemm_tritonbench(
tb_op: object, x: torch.Tensor, w: torch.Tensor
) -> Callable[[], torch.Tensor]:
"""
Wrapper for TritonBench compatibility.

Args:
tb_op: TritonBench operator instance
x (torch.Tensor): Input tensor in bfloat16 format.
w (torch.Tensor): Weight tensor in int16 format.

Returns:
Callable that returns output tensor in bfloat16 format.
"""
    # Determine whether to use transpose mode based on the tritonbench operator
transpose = getattr(tb_op, "transpose", False)

def run_kernel() -> torch.Tensor:
return bf16xint16_gemm(x, w, transpose=transpose)

return run_kernel
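

# Sketch of how a benchmarking harness might drive the wrapper (call pattern assumed, not the tritonbench API):
#   kernel_fn = bf16xint16_gemm_tritonbench(tb_op, x, w)
#   out = kernel_fn()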


# %%
def reference_bf16xint16_pytorch(
x: torch.Tensor, w: torch.Tensor, transpose: bool = False
) -> torch.Tensor:
"""
Reference implementation using PyTorch operations.

Args:
x (torch.Tensor): Input tensor.
w (torch.Tensor): Weight tensor.
transpose (bool): Transpose mode flag.

Returns:
torch.Tensor: Output tensor in bfloat16 format.
"""
if transpose:
x_bf16 = x.to(torch.bfloat16)
return torch.matmul(x_bf16, w)
w_bf16 = w.to(torch.bfloat16)
return torch.matmul(x, w_bf16)


# %%
def check(m: int, k: int, n: int) -> None:
"""
Test the bf16 x int16 GEMM implementation against the PyTorch reference.

Args:
m (int): Number of rows.
k (int): Shared dimension.
        n (int): Number of columns.
"""
x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)

result = bf16xint16_gemm(x, w, transpose=False)
expected = reference_bf16xint16_pytorch(x, w, transpose=False)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

x_int16 = torch.randint(
-(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16
)
w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)

result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)


# %%
def main() -> None:
"""
Main entry point that runs the bf16xint16 kernel verification with different tensor sizes.
"""
check(256, 256, 256)
check(512, 512, 512)
check(65536, 1024, 1280)


# %%
if __name__ == "__main__":
main()
86 changes: 86 additions & 0 deletions test/test_examples.expected
@@ -404,6 +404,92 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
_launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2)
return out.view(q_in.size())

--- assertExpectedJournal(TestExamples.test_bf16xint16)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion__bf16xint16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
x_tile = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
load_1 = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
v_0 = tl.cast(load_1, tl.bfloat16)
acc = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(v_0, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
v_1 = tl.cast(acc, tl.bfloat16)
tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)

def _bf16xint16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
"""
x is bf16, w is int16.
"""
M, K = x.shape
K2, N = w.shape
assert K == K2, f'size mismatch {K} != {K2}'
out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
_BLOCK_SIZE_2 = 16
_launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
return out

--- assertExpectedJournal(TestExamples.test_bf16xint16)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion__int16xbf16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
v_0 = tl.cast(load, tl.bfloat16)
w_tile = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
acc = tl.dot(tl.cast(v_0, tl.bfloat16), tl.cast(w_tile, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
v_1 = tl.cast(acc, tl.bfloat16)
tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)

def _int16xbf16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
"""
x is int16, w is bf16.
"""
M, K = x.shape
K2, N = w.shape
assert K == K2, f'size mismatch {K} != {K2}'
out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
_BLOCK_SIZE_2 = 16
_launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
return out

--- assertExpectedJournal(TestExamples.test_bmm)
from __future__ import annotations

32 changes: 32 additions & 0 deletions test/test_examples.py
@@ -355,6 +355,38 @@ def test_low_mem_dropout(self):
check_example("low_mem_dropout", (p, grad_y, seed), grad_x),
)

@skipIfRocm("precision differences with bf16xint16 operations on rocm")
def test_bf16xint16(self):
from examples.bf16xint16_gemm import reference_bf16xint16_pytorch

m, k, n = 65536, 1024, 1280

x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)

self.assertExpectedJournal(
check_example(
"bf16xint16_gemm",
(x, w),
reference_bf16xint16_pytorch(x, w, False),
fn_name="_bf16xint16_gemm",
)
)

x_int16 = torch.randint(
-(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
)
w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)

self.assertExpectedJournal(
check_example(
"bf16xint16_gemm",
(x_int16, w_bf16),
reference_bf16xint16_pytorch(x_int16, w_bf16, True),
fn_name="_int16xbf16_gemm",
)
)

def test_rms_norm_fwd(self):
args = (
torch.randn([128, 256], device=DEVICE, dtype=torch.float16),