adding fused uint4x2_mixed_mm to inductor #106516

Closed · wants to merge 9 commits
125 changes: 124 additions & 1 deletion test/inductor/test_pattern_matcher.py
@@ -143,7 +143,7 @@ def fn(a, b):
self._test_mixed_impl(fn, args, True, False)

@inductor_config.patch(use_mixed_mm=True)
def test_mixed_mm_cpu_works(self):
def test_mixed_mm_cpu(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))

@@ -153,6 +153,129 @@ def fn(a, b):
)
self._test_mixed_impl(fn, args, False, False)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_epi(self):
def fn(a, b, c, d):
return (
torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
* c
+ d
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
torch.randn(8, device="cuda"),
torch.randn(8, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
self.assertTrue("fused_add_mm_mul" in code)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_fail_to_match(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
( # cpu
torch.randn(8, 8),
torch.randint(0, 255, (4, 8), dtype=torch.uint8),
),
( # int8
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"),
            ),  # int8 is not matched: bitshift numerics differ between triton and pytorch for int8
        ]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)

@inductor_config.patch(use_mixed_mm=False)
def test_uint4x2_mixed_mm_gating_works(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)

def test_addmm(self):
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -1799,6 +1799,26 @@ def fn(a, b, scale, bias):
check_lowp=True,
)

@config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

self.common(
fn,
(
torch.randn(8, 8),
torch.randint(0, 255, (4, 8), dtype=torch.uint8),
),
check_lowp=True,
)

def test_scalar_input(self):
def fn(x, y):
a = torch.div(x, y, rounding_mode="floor")
80 changes: 76 additions & 4 deletions torch/_inductor/fx_passes/post_grad.py
@@ -155,6 +155,81 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) # type: ignore[attr-defined]


def cuda_and_enabled_mixed_mm(match):
return (config.use_mixed_mm or config.force_mixed_mm) and getattr(
match.kwargs["mat1"].meta.get("val"), "is_cuda", False
)


def cuda_and_enabled_mixed_mm_and_not_int8(match):
return (
cuda_and_enabled_mixed_mm(match)
and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
!= torch.int8
) # bitshift numerics in triton and pytorch don't match for torch.int8


"""
this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
(where the int4 and uint4x2 are represented with int8 and uint8 respectively)
where every other row of the int4 is packed with the row above it as:
uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4

unpack formulas:
int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8

thus matching on unpack formula:
torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))

note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
of the kernel matches the pytorch formula for all dtypes except torch.int8
where the bitwise numerics in triton do not match those in pytorch.
"""


@register_lowering_pattern(
CallFunction(
aten.mm.default,
KeywordArg("mat1"),
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.reshape.default,
CallFunction(
aten.cat.default,
ListOf(
CallFunction(
aten.bitwise_and.Scalar,
KeywordArg("mat2"),
0xF,
),
CallFunction(
aten.__rshift__.Scalar,
KeywordArg("mat2"),
4,
),
),
1,
),
KeywordArg("mat2_mm_shape"),
),
KeywordArg("mat2_dtype"),
),
8,
),
),
extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
)
def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
mat1, mat2, mat2_mm_shape, mat2_dtype
)


"""
torch.mm(mat1, mat2.to(mat2_dtype))
"""
@@ -170,10 +170,7 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
KeywordArg("mat2_dtype"),
),
),
extra_check=(
lambda match: (config.use_mixed_mm or config.force_mixed_mm)
and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
), # needs cuda
extra_check=cuda_and_enabled_mixed_mm,
)
def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
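For illustration only (an editorial sketch, not part of this patch; shapes are arbitrary), the round trip below packs a small int4 tensor into the uint4x2 layout described in the post_grad.py docstring above and recovers it with the exact unpack expression the lowering pattern matches:

import torch

# pack a [K, N] int4 tensor (stored as int8 in [-8, 7]) into a [K/2, N] uint8 "uint4x2" tensor
int4 = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # K=4, N=8
lo = int4[0::2].to(torch.int16) + 8   # rows 2*k   -> low nibble
hi = int4[1::2].to(torch.int16) + 8   # rows 2*k+1 -> high nibble
packed = (lo | (hi << 4)).to(torch.uint8)  # [K/2, N]

# unpack with the expression the pattern matches (cast to int8 here instead of a float dtype)
unpacked = (
    torch.cat((packed & 0xF, packed >> 4), 1)
    .reshape(-1, packed.shape[1])
    .to(torch.int8)
    .sub(8)
)
assert torch.equal(unpacked, int4)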
4 changes: 3 additions & 1 deletion torch/_inductor/kernel/mm_common.py
@@ -119,14 +119,16 @@ def mm_options(config, sym_k, layout, b_prologue_cast_type=None):
)


def mm_args(mat1, mat2, *others, layout=None, out_dtype=None):
def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False):
"""
Common arg processing for mm,bmm,addmm,etc
"""
mat1, mat2 = realize_inputs(mat1, mat2)
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
if use_4x2_dim:
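        # mat2 is a packed uint4x2 tensor, so each stored row carries two logical K rows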
k2 = k2 * 2
k = V.graph.sizevars.guard_equals(k1, k2)
if layout is None:
from torch._inductor.ir import FixedLayout
81 changes: 81 additions & 0 deletions torch/_inductor/kernel/unpack_mixed_mm.py
@@ -0,0 +1,81 @@
import logging

from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options

log = logging.getLogger(__name__)

uint4x2_mixed_mm_template = TritonTemplate(
name="uint4x2_mixed_mm",
grid=mm_grid,
source=r"""
{{def_kernel("A", "B")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}

# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N

# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
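    # each packed row of B holds two K values: even rk reads the low nibble (shift 0), odd rk the high nibble (shift 4)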
b_shifts = 4*(rk%2)
b_subs = 8*(1-(rk%2))

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
b = ((b >> b_shifts[:, None]) & 0xF) - 8
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K//2 * stride_bk

# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)

# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
choices = []
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(
choices,
(mat1, mat2),
layout,
**mm_options(config, k, layout, b_prologue_cast_type),
)
return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
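
As a further check (again an editorial sketch, not part of the patch), the per-row nibble selection that the template implements with b_shifts = 4*(rk % 2) can be reproduced in plain PyTorch and compared against the cat/reshape unpack that the pattern matches:

import torch

K2, N = 4, 8  # packed K dimension and N, arbitrary here
packed = torch.randint(0, 256, (K2, N), dtype=torch.uint8)

# unpack via cat/reshape, as matched by the lowering pattern
ref = torch.cat((packed & 0xF, packed >> 4), 1).reshape(-1, N).to(torch.int32) - 8

# unpack via per-row shift, mirroring b_shifts = 4*(rk % 2) in the triton template
rk = torch.arange(2 * K2)
shifts = (4 * (rk % 2)).to(torch.int32)[:, None]
alt = ((packed[rk // 2].to(torch.int32) >> shifts) & 0xF) - 8

assert torch.equal(ref, alt)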