
Commit

[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)

As part of #125683, this PR adds initial bf16/fp16 gemm template support, with the micro-gemm implemented via fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet; that will be added in the next PR.
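
For readers skimming the diff below, here is a minimal PyTorch sketch of the numerics the template targets, i.e. what "computed with fp32" means for bf16/fp16 inputs. It is an illustration only, not the generated C++ kernel, and the function name is made up:

import torch

def gemm_fp32_compute(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # bf16/fp16 operands are upconverted, the GEMM accumulates in fp32,
    # and the result is downcast only once at the end.
    assert x.dtype in (torch.bfloat16, torch.half)
    acc = x.to(torch.float32) @ w.to(torch.float32)
    return acc.to(x.dtype)

x = torch.randn(8, 64, dtype=torch.bfloat16)
w = torch.randn(64, 32, dtype=torch.bfloat16)
y = gemm_fp32_compute(x, w)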

Pull Request resolved: #126068
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019
jgong5 authored and pytorchmergebot committed May 23, 2024
1 parent 57108d9 commit 57c185b
Showing 9 changed files with 360 additions and 112 deletions.
15 changes: 11 additions & 4 deletions test/inductor/test_cpu_select_algorithm.py
@@ -77,11 +77,11 @@ class TestSelectAlgorithm(TestCase):
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (1, 2, 1000))
@parametrize("in_features", (1, 2, 1000))
@parametrize("out_features", (1, 32, 1024))
@parametrize("in_features", (1, 1000))
@parametrize("out_features", (1, 1024))
@parametrize("bias", (True, False))
@parametrize("input_3d", (True, False))
@dtypes(torch.float)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_static_shapes(
self, batch_size, in_features, out_features, bias, input_3d, dtype
):
@@ -97,7 +97,14 @@ def forward(self, x):
mod = M(bias=bias).to(dtype=dtype).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype=dtype)
self.common(mod, (v,))
# For bfloat16 and half, we have to relax the tolerance
# due to the different associative orders used in different
# kernel implementations
atol, rtol = 1e-4, 1e-4
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(mod, (v,), atol=atol, rtol=rtol)
if (
counters["inductor"]["decompose_mm"] > 0
or counters["inductor"]["decompose_addmm"] > 0
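The relaxed tolerances above reflect that floating-point reduction is not associative: two correct kernels that round intermediates to bf16/fp16 at different points, or sum in a different order, can disagree slightly. A standalone illustration of the effect (my own example, not part of the test file):

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.half)

s_once = x.sum()                                # single reduction
s_chunked = sum(c.sum() for c in x.split(256))  # partial sums rounded back to fp16
# The two results generally differ a little because the additions happen in a
# different order with different intermediate rounding.
print(s_once.item(), s_chunked.item())
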
3 changes: 1 addition & 2 deletions torch/_inductor/codegen/cpp.py
@@ -2809,9 +2809,8 @@ def store_reduction(self, name, index, value):
return self.simd_vec

def __exit__(self, exc_type, exc_val, exc_tb):
assert self._orig_wrapper_code is not None
# Restore the wrapper_code
V.graph.wrapper_code = self._orig_wrapper_code
V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment]
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

def __enter__(self):
69 changes: 54 additions & 15 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -147,6 +147,7 @@ def __init__(
beta=1,
alpha=1,
):
assert layout.dtype in [torch.float, torch.bfloat16, torch.half]
super().__init__("packed_gemm", input_nodes, layout)
self.beta = beta
self.alpha = alpha
@@ -212,7 +213,13 @@ def cache_blocking(self) -> GemmBlocking:

@staticmethod
def add_choices(
choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None
choices,
layout,
input_nodes,
beta=1,
alpha=1,
trans_w=False,
input_indices=None,
):
if input_indices is None:
input_indices = list(range(len(input_nodes)))
@@ -232,28 +239,58 @@ def reorder_and_filter(inputs, layout_or_out):
w_idx = input_indices[2]
return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out

def transpose_weight(inputs, layout_or_out):
def maybe_to_dense(inputs, layout_or_out):
new_inputs = list(inputs)
if isinstance(inputs[1], torch.Tensor):
W = inputs[1]
new_inputs[1] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out

def normalize_shapes(inputs, layout_or_out):
if not trans_w:
return inputs, layout_or_out

new_inputs = list(inputs)
X = inputs[0]
W = inputs[1]
B = inputs[2] if len(inputs) > 2 else None
if isinstance(W, ir.IRNode):
if not isinstance(W, ir.TensorBox):
W = ir.TensorBox(W)
new_inputs[1] = L.permute(W, [1, 0])
return new_inputs, layout_or_out
if trans_w:
if not isinstance(W, ir.TensorBox):
W = ir.TensorBox(W)
W = L.permute(W, [1, 0])
else:
assert isinstance(W, torch.Tensor)
new_inputs[1] = W.transpose(0, 1)
if trans_w:
assert isinstance(W, torch.Tensor)
W = W.transpose(0, 1)
if B is not None:
if isinstance(B, ir.IRNode):
if not isinstance(B, ir.TensorBox):
B = ir.TensorBox(B)
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
else:
assert isinstance(B, torch.Tensor)
B = B.expand(X.shape[0], B.shape[-1])
new_inputs[1] = W
if B is not None:
new_inputs[2] = B
return new_inputs, layout_or_out

# TODO(jgong5): decide proper number of threads per problem size
num_threads = parallel_num_threads()
new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout))
new_inputs, _ = normalize_shapes(
*maybe_to_dense(*reorder_and_filter(input_nodes, layout))
)
m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
micro_gemm = create_micro_gemm(
"micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads
"micro_gemm",
m,
n,
k,
input_dtype=layout.dtype,
output_dtype=torch.float,
alpha=alpha,
num_threads=num_threads,
)
assert micro_gemm is not None
_, block_n, _ = micro_gemm.register_blocking
@@ -300,7 +337,9 @@ def pack_weight(inputs, layout_or_out):
return new_inputs, layout_or_out

def preprocessor(inputs, layout):
return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout)))
return pack_weight(
*normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
)

def postprocessor(output):
if isinstance(output, ir.TensorBox):
@@ -315,7 +354,7 @@ def postprocessor(output):
W = V.graph.constants[W_node.get_name()]
new_input_nodes[1] = W
new_input_nodes, _ = pack_weight(
*transpose_weight(new_input_nodes, layout)
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
)
W_packed = new_input_nodes[1]
W_packed_constant = V.graph.add_tensor_constant(W_packed)
@@ -358,8 +397,7 @@ def render( # type: ignore[override]

template_buffer = Y
Y_is_transposed = False
# TODO(jgong5): support local accumulation
use_local_acc = False
use_local_acc = self.layout.dtype != torch.float
if epilogue_nodes:
Y = cast(ir.Buffer, epilogue_nodes[-1])
assert Y.get_name() in V.kernel.inplace_update_buffers
@@ -373,7 +411,8 @@
self.m,
self.n,
self.k,
self.layout.dtype,
input_dtype=self.layout.dtype,
output_dtype=torch.float,
alpha=self.alpha,
num_threads=self.num_threads,
)
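To summarize the add_choices() changes in this file: the weight preprocessing pipeline is now reorder_and_filter -> maybe_to_dense -> normalize_shapes -> pack_weight, the micro-gemm is created with input_dtype equal to the layout dtype but output_dtype fixed to fp32, and render() enables a local accumulation buffer whenever the layout dtype is not fp32. Below is a hedged, plain-tensor sketch of that preprocessing chain; it handles only torch.Tensor inputs, omits padding, and the helper names mirror the diff but are simplified:

import torch

def normalize_shapes(x, w, b, trans_w=True):
    if trans_w:
        w = w.transpose(0, 1)                  # view W as (K, N)
    if b is not None:
        b = b.expand(x.shape[0], b.shape[-1])  # broadcast bias to (M, N)
    return x, w, b

def pack_weight(w, block_n):
    k, n = w.shape
    assert n % block_n == 0                    # the real code pads N; omitted here
    # reorder W into (N // block_n, K, block_n) blocks for the micro-kernel
    return w.reshape(k, n // block_n, block_n).permute(1, 0, 2).contiguous()

x = torch.randn(8, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)  # (N, K), as nn.Linear stores it
b = torch.randn(32, dtype=torch.bfloat16)
x, w, b = normalize_shapes(x, w, b)
w_packed = pack_weight(w, block_n=16)          # block_n comes from the chosen micro-gemm
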
90 changes: 68 additions & 22 deletions torch/_inductor/codegen/cpp_micro_gemm.py
@@ -59,7 +59,11 @@ def __init__(

def get_common_options(self):
return {
"torch": torch,
"kernel_name": self.name,
"input_dtype": self.input_dtype,
"output_dtype": self.output_dtype,
"compute_dtype": self.compute_dtype,
"input_t": DTYPE_TO_CPP[self.input_dtype],
"output_t": DTYPE_TO_CPP[self.output_dtype],
"compute_t": DTYPE_TO_CPP[self.compute_dtype],
Expand Down Expand Up @@ -136,6 +140,29 @@ def inner(cls):
return inner


def generate_gemm_config(
vec_isa_cls,
register_blockings,
input_dtype=torch.float,
output_dtype=None,
compute_dtype=None,
):
if output_dtype is None:
output_dtype = input_dtype
if compute_dtype is None:
compute_dtype = output_dtype
return [
CppMicroGemmConfig(
input_dtype,
output_dtype,
compute_dtype,
vec_isa_cls,
GemmBlocking(*blocking),
)
for blocking in register_blockings
]


class CppMicroGemmRef(CppMicroGemm):
"""
A reference implementation of the CppMicroGemm class with naive C++ code.
Expand Down Expand Up @@ -170,28 +197,41 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str:


@register_micro_gemm(
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1)
*generate_gemm_config(
VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float
),
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1)
*generate_gemm_config(
VecAVX512,
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
input_dtype=torch.bfloat16,
output_dtype=torch.float,
),
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1)
*generate_gemm_config(
VecAVX512,
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
input_dtype=torch.half,
output_dtype=torch.float,
),
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1)
*generate_gemm_config(
VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float
),
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1)
*generate_gemm_config(
VecAVX2,
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
input_dtype=torch.bfloat16,
output_dtype=torch.float,
),
CppMicroGemmConfig(
torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1)
*generate_gemm_config(
VecAVX2,
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
input_dtype=torch.half,
output_dtype=torch.float,
),
)
class CppMicroGemmFP32Vec(CppMicroGemm):
"""
This class generates the code for fp32 micro gemm using vec instructions.
This class generates the code for micro gemm using fp32 vec instructions for compute.
It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
"""

TEMPLATE_ENTRY = r"""
@@ -239,22 +279,23 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
const {{input_t}}* __restrict__ A,
const {{input_t}}* __restrict__ B,
{{output_t}}* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
using Vectorized = at::vec::Vectorized<float>;
using Vectorized = at::vec::Vectorized<{{compute_t}}>;
using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
constexpr auto VLEN = Vectorized::size();
constexpr auto ROWS = BLOCK_M;
constexpr auto COLS = BLOCK_N / VLEN;
Vectorized va;
at::vec::VectorizedN<float, COLS> vb;
at::vec::VectorizedN<float, ROWS*COLS> vc;
at::vec::VectorizedN<{{compute_t}}, COLS> vb;
at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
auto loadc = [&](auto i) {
if constexpr (accum) {
@@ -273,14 +314,19 @@
if constexpr (col == 0) {
{%- if alpha != 1 %}
va = Vectorized(A[row * lda + k] * {{alpha}});
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
{%- else %}
va = Vectorized(A[row * lda + k]);
va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
{%- endif %}
}
if constexpr (row == 0) {
{%- if input_dtype == torch.bfloat16 or input_dtype == torch.float16 %}
auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
vb[col] = at::vec::convert<{{compute_t}}>(b);
{%- else %}
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
{%- endif %}
}
constexpr int idx = row * COLS + col;
@@ -349,7 +395,7 @@ def create_from_config(cls, config: CppMicroGemmConfig):
if output_dtype is None:
output_dtype = input_dtype
if compute_dtype is None:
compute_dtype = input_dtype
compute_dtype = output_dtype
if num_threads < 0:
num_threads = parallel_num_threads()
vec_isa = pick_vec_isa()
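The registration changes above reuse the same fp32-compute micro-kernel for fp32, bf16, and fp16 inputs: generate_gemm_config expands one list of register blockings into per-dtype CppMicroGemmConfig entries, and the template now loads bf16/fp16 elements of A and B, converts them to fp32 vectors, and accumulates with fp32 FMAs. A hedged Python sketch of that compute pattern (it mirrors the structure, not the vectorized C++ output):

import torch

def micro_gemm_fp32_compute(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (BLOCK_M, K) in bf16/fp16, B: (K, BLOCK_N) in bf16/fp16
    block_m, k = A.shape
    _, block_n = B.shape
    C = torch.zeros(block_m, block_n, dtype=torch.float32)  # fp32 accumulator
    for kk in range(k):
        va = A[:, kk].to(torch.float32).unsqueeze(1)  # column of A, converted
        vb = B[kk, :].to(torch.float32).unsqueeze(0)  # row of B, converted
        C += va * vb                                  # rank-1 update in fp32
    return C  # left in fp32, matching output_dtype=torch.float

A = torch.randn(8, 64, dtype=torch.bfloat16)
B = torch.randn(64, 48, dtype=torch.bfloat16)
C = micro_gemm_fp32_compute(A, B)
torch.testing.assert_close(C, A.float() @ B.float(), rtol=1e-4, atol=1e-4)
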

1 comment on commit 57c185b

@pytorchmergebot
Collaborator


Reverted #124021 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it has a land race and is failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/2ac33a9f663269e6060246337c776a20c3b7c858 (comment)
