Skip to content
125 changes: 124 additions & 1 deletion test/inductor/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def fn(a, b):
self._test_mixed_impl(fn, args, True, False)

@inductor_config.patch(use_mixed_mm=True)
def test_mixed_mm_cpu_works(self):
def test_mixed_mm_cpu(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))

Expand All @@ -153,6 +153,129 @@ def fn(a, b):
)
self._test_mixed_impl(fn, args, False, False)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_epi(self):
def fn(a, b, c, d):
return (
torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
* c
+ d
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
torch.randn(8, device="cuda"),
torch.randn(8, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
self.assertTrue("fused_add_mm_mul" in code)

@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_fail_to_match(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
( # cpu
torch.randn(8, 8),
torch.randint(0, 255, (4, 8), dtype=torch.uint8),
),
( # int8
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"),
), # we don't match for int8 since numerics
] # for int8 bitshifts don't match between triton and pytorch

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)

@inductor_config.patch(use_mixed_mm=False)
def test_uint4x2_mixed_mm_gating_works(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)

def test_addmm(self):
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
Expand Down
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,26 @@ def fn(a, b, scale, bias):
check_lowp=True,
)

@config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)

self.common(
fn,
(
torch.randn(8, 8),
torch.randint(0, 255, (4, 8), dtype=torch.uint8),
),
check_lowp=True,
)

def test_scalar_input(self):
def fn(x, y):
a = torch.div(x, y, rounding_mode="floor")
Expand Down
80 changes: 76 additions & 4 deletions torch/_inductor/fx_passes/post_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,81 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) # type: ignore[attr-defined]


def cuda_and_enabled_mixed_mm(match):
return (config.use_mixed_mm or config.force_mixed_mm) and getattr(
match.kwargs["mat1"].meta.get("val"), "is_cuda", False
)


def cuda_and_enabled_mixed_mm_and_not_int8(match):
return (
cuda_and_enabled_mixed_mm(match)
and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
!= torch.int8
) # bitshift numerics in triton and pytorch don't match for torch.int8


"""
this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
(where the int4 and uint4x2 are represented with int8 and uint8 respectively)
where every other row of the int4 is packed with the row above it as:
uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4

unpack formulas:
int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8

thus matching on unpack formula:
torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))

note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
of the kernel matches the pytorch formula for all dtypes except torch.int8
where the bitwise numerics in triton do not match those in pytorch.
"""


@register_lowering_pattern(
CallFunction(
aten.mm.default,
KeywordArg("mat1"),
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.reshape.default,
CallFunction(
aten.cat.default,
ListOf(
CallFunction(
aten.bitwise_and.Scalar,
KeywordArg("mat2"),
0xF,
),
CallFunction(
aten.__rshift__.Scalar,
KeywordArg("mat2"),
4,
),
),
1,
),
KeywordArg("mat2_mm_shape"),
),
KeywordArg("mat2_dtype"),
),
8,
),
),
extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
)
def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
mat1, mat2, mat2_mm_shape, mat2_dtype
)


"""
torch.mm(mat1, mat2.to(mat2_dtype))
"""
Expand All @@ -170,10 +245,7 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
KeywordArg("mat2_dtype"),
),
),
extra_check=(
lambda match: (config.use_mixed_mm or config.force_mixed_mm)
and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
), # needs cuda
extra_check=cuda_and_enabled_mixed_mm,
)
def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/kernel/mm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,16 @@ def mm_options(config, sym_k, layout, b_prologue_cast_type=None):
)


def mm_args(mat1, mat2, *others, layout=None, out_dtype=None):
def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False):
"""
Common arg processing for mm,bmm,addmm,etc
"""
mat1, mat2 = realize_inputs(mat1, mat2)
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
if use_4x2_dim:
k2 = k2 * 2
k = V.graph.sizevars.guard_equals(k1, k2)
if layout is None:
from torch._inductor.ir import FixedLayout
Expand Down
81 changes: 81 additions & 0 deletions torch/_inductor/kernel/unpack_mixed_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import logging

from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options

log = logging.getLogger(__name__)

uint4x2_mixed_mm_template = TritonTemplate(
name="uint4x2_mixed_mm",
grid=mm_grid,
source=r"""
{{def_kernel("A", "B")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}

# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N

# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
b_shifts = 4*(rk%2)
b_subs = 8*(1-(rk%2))

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
b = ((b >> b_shifts[:, None]) & 0xF) - 8
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K//2 * stride_bk

# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)

# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
choices = []
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(
choices,
(mat1, mat2),
layout,
**mm_options(config, k, layout, b_prologue_cast_type),
)
return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)