Skip to content

Commit

Permalink
[inductor][cpp] BF16 AMX micro-gemm support (#127195)
Browse files Browse the repository at this point in the history
This PR adds the intrinsics based micro-gemm for BF16 using Advanced Matrix eXtension (AMX) instructions available in Intel 4th and 5th Xeon processors. A compilation check is added to `codecache.py` to check the validity of the compiler support. Also, since AMX requires an initialization in the Linux kernel to extra register states, an initialization function is added to do that and triggered via `codecache.py`.

Performance speedups with >=10% on BF16 AMP, max_autotune vs. no autotune, measured on Intel(R) Xeon(R) Platinum 8488C:
Static shapes
Single-threaded
| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| timm_models | mixer_b16_224 | 1.54 |
| timm_models | convit_base | 1.53 |
| huggingface | MobileBertForQuestionAnswering | 1.52 |
| torchbench | fastNLP_Bert | 1.44 |
| torchbench | llama | 1.33 |
| timm_models | swin_base_patch4_window7_224 | 1.31 |
| torchbench | dlrm | 1.28 |
| torchbench | timm_vision_transformer_large | 1.28 |
| huggingface | MobileBertForMaskedLM | 1.27 |
| timm_models | vit_base_patch16_224 | 1.26 |
| timm_models | beit_base_patch16_224 | 1.23 |
| timm_models | jx_nest_base | 1.21 |
| torchbench | pyhpc_equation_of_state | 1.18 |
| huggingface | Speech2Text2ForCausalLM | 1.15 |
| timm_models | pit_b_224 | 1.14 |
| timm_models | twins_pcpvt_base | 1.14 |
| torchbench | maml_omniglot | 1.1 |
| timm_models | botnet26t_256 | 1.1 |

Multi-threaded
| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | BERT_pytorch | 1.35 |
| torchbench | lennard_jones | 2.43 |
| torchbench | hf_Albert | 1.35 |
| torchbench | hf_T5 | 1.34 |
| torchbench | soft_actor_critic | 1.34 |
| torchbench | fastNLP_Bert | 1.28 |
| huggingface | LayoutLMForSequenceClassification | 1.26 |
| torchbench | llama | 1.24 |
| huggingface | GPT2ForSequenceClassification | 1.19 |
| torchbench | hf_Bart | 1.17 |
| torchbench | hf_Bert_large | 1.16 |
| torchbench | hf_GPT2 | 1.16 |
| timm_models | gmixer_24_224 | 1.16 |
| torchbench | hf_GPT2_large | 1.15 |
| torchbench | maml_omniglot | 1.14 |
| torchbench | hf_Bert | 1.13 |
| torchbench | hf_DistilBert | 1.13 |
| torchbench | hf_T5_large | 1.12 |
| huggingface | MT5ForConditionalGeneration | 1.11 |

Dynamic shapes
Single-threaded
| Model Family | Model Name | Speedup |
|--------------|------------|-------|
| timm_models | mixer_b16_224 | 1.52 |
| timm_models | convit_base | 1.5 |
| huggingface | MobileBertForQuestionAnswering | 1.49 |
| torchbench | fastNLP_Bert | 1.42 |
| torchbench | timm_vision_transformer_large | 1.28 |
| timm_models | swin_base_patch4_window7_224 | 1.27 |
| torchbench | llama | 1.26 |
| huggingface | MobileBertForMaskedLM | 1.25 |
| timm_models | vit_base_patch16_224 | 1.25 |
| timm_models | beit_base_patch16_224 | 1.24 |
| timm_models | jx_nest_base | 1.2 |
| torchbench | dlrm | 1.19 |
| timm_models | pit_b_224 | 1.13 |
| timm_models | twins_pcpvt_base | 1.13 |
| torchbench | hf_Bert_large | 1.12 |
| torchbench | hf_BigBird | 1.11 |
| huggingface | Speech2Text2ForCausalLM | 1.11 |
| timm_models | eca_botnext26ts_256 | 1.11 |
| timm_models | botnet26t_256 | 1.1 |

Multi-threaded
| Model Family | Model Name | Speedup |
|--------------|------------|-------|
| torchbench | BERT_pytorch | 1.18 |
| torchbench | lennard_jones | 2.18 |
| torchbench | hf_Albert | 1.37 |
| torchbench | soft_actor_critic | 1.31 |
| huggingface | GPT2ForSequenceClassification | 1.29 |
| torchbench | hf_T5 | 1.28 |
| torchbench | fastNLP_Bert | 1.27 |
| torchbench | hf_Bart | 1.21 |
| torchbench | hf_Bert_large | 1.19 |
| torchbench | hf_T5_large | 1.19 |
| torchbench | hf_Bert | 1.16 |
| torchbench | hf_GPT2 | 1.16 |
| huggingface | CamemBert | 1.16 |
| torchbench | hf_GPT2_large | 1.13 |
| torchbench | functorch_maml_omniglot | 1.12 |
| huggingface | BertForMaskedLM | 1.12 |
| huggingface | MT5ForConditionalGeneration | 1.12 |
| torchbench | hf_DistilBert | 1.11 |
| timm_models | mixnet_l | 1.11 |
| timm_models | tf_mixnet_l | 1.11 |

No perf regressions.

Pull Request resolved: #127195
Approved by: https://github.com/jansel
  • Loading branch information
jgong5 authored and pytorchmergebot committed Jun 21, 2024
1 parent 632910e commit 914d3ca
Show file tree
Hide file tree
Showing 12 changed files with 545 additions and 42 deletions.
47 changes: 47 additions & 0 deletions aten/src/ATen/cpu/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#if !defined(__s390x__ ) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif
#if defined(__linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

namespace at::cpu {
bool is_cpu_support_avx2() {
Expand All @@ -28,4 +32,47 @@ bool is_cpu_support_avx512_vnni() {
#endif
}

bool is_cpu_support_amx_tile() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile();
#else
return false;
#endif
}

bool init_amx() {
if (!is_cpu_support_amx_tile()) {
return false;
}

#if defined(__linux__) && !defined(__ANDROID__)
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18
#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)

#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023

unsigned long bitmask = 0;
// Request permission to use AMX instructions
long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
if (rc) {
return false;
}
// Check if the system supports AMX instructions
rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
if (rc) {
return false;
}
if (bitmask & XFEATURE_MASK_XTILE) {
return true;
}
return false;
#else
return true;
#endif
}

} // namespace at::cpu
6 changes: 6 additions & 0 deletions aten/src/ATen/cpu/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ TORCH_API bool is_cpu_support_avx512();
// Detect if CPU support Vector Neural Network Instruction.
TORCH_API bool is_cpu_support_avx512_vnni();

// Detect if CPU support Advanced Matrix Extension.
TORCH_API bool is_cpu_support_amx_tile();

// Enable the system to use AMX instructions.
TORCH_API bool init_amx();

} // namespace at::cpu
4 changes: 2 additions & 2 deletions test/inductor/test_cpu_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,8 @@ def fn(x):
)
@patch("torch.cuda.is_available", lambda: False)
def test_auto_simd(self):
vec_avx512 = codecache.supported_vec_isa_list[0]
vec_avx2 = codecache.supported_vec_isa_list[1]
vec_avx512 = codecache.supported_vec_isa_list[1]
vec_avx2 = codecache.supported_vec_isa_list[2]
self.assertTrue(vec_avx512.bit_width() == 512)
self.assertTrue(vec_avx2.bit_width() == 256)
self.assertTrue(vec_avx512.nelements() == 16)
Expand Down
33 changes: 33 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor.codecache import VecAMX
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_device_type import (
dtypes,
Expand Down Expand Up @@ -333,6 +334,37 @@ def forward(self, x):
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("bias", (True, False))
def test_linear_amx(self, bias):
batch_size = 1024
in_features = 1024
out_features = 1024
dtype = torch.bfloat16

class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)

def forward(self, x):
return self.linear(x)

counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
vec_amx = VecAMX()
if vec_amx:
self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0)
else:
self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(TestCase):
Expand All @@ -351,6 +383,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
test_linear_with_unary_binary_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_unary_binary
)
test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
Expand Down
2 changes: 2 additions & 0 deletions torch/_C/_cpu.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ from torch.types import _bool
def _is_cpu_support_avx2() -> _bool: ...
def _is_cpu_support_avx512() -> _bool: ...
def _is_cpu_support_avx512_vnni() -> _bool: ...
def _is_cpu_support_amx_tile() -> _bool: ...
def _init_amx() -> _bool: ...
4 changes: 4 additions & 0 deletions torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@
"torch._C._cpu._is_cpu_support_avx2",
"torch._C._cpu._is_cpu_support_avx512",
"torch._C._cpu._is_cpu_support_avx512_vnni",
"torch._C._cpu._is_cpu_support_amx_tile",
"torch._C._cpu._init_amx",
"torch._C._crash_if_aten_asan",
"torch._C._crash_if_csrc_asan",
"torch._C._crash_if_csrc_ubsan",
Expand Down Expand Up @@ -2423,6 +2425,8 @@
"torch.cpu._is_cpu_support_avx2",
"torch.cpu._is_cpu_support_avx512",
"torch.cpu._is_cpu_support_avx512_vnni",
"torch.cpu._is_cpu_support_amx_tile",
"torch.cpu._init_amx",
"torch.cpu.current_device",
"torch.cpu.current_stream",
"torch.cpu.device_count",
Expand Down
67 changes: 56 additions & 11 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,18 +1336,11 @@ def build_arch_flags(self) -> str:
def __hash__(self) -> int:
return hash(str(self))

@functools.lru_cache(None) # noqa: B019
def __bool__(self) -> bool:
def check_build(self, code) -> bool:
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions

if config.cpp.vec_isa_ok is not None:
return config.cpp.vec_isa_ok

if config.is_fbcode():
return True

key, input_path = write(
VecISA._avx_code,
code,
"cpp",
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
)
Expand Down Expand Up @@ -1385,6 +1378,16 @@ def __bool__(self) -> bool:

return True

@functools.lru_cache(None) # noqa: B019
def __bool__(self) -> bool:
if config.cpp.vec_isa_ok is not None:
return config.cpp.vec_isa_ok

if config.is_fbcode():
return True

return self.check_build(VecISA._avx_code)


@dataclasses.dataclass
class VecNEON(VecISA):
Expand Down Expand Up @@ -1418,6 +1421,46 @@ def __str__(self) -> str:
__hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAMX(VecAVX512):
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"

def __str__(self) -> str:
return super().__str__() + " amx_tile"

__hash__: Callable[[VecISA], Any] = VecISA.__hash__

_amx_code = """
#include <cstdint>
#include <immintrin.h>
struct amx_tilecfg {
uint8_t palette_id;
uint8_t start_row;
uint8_t reserved_0[14];
uint16_t colsb[16];
uint8_t rows[16];
};
extern "C" void __amx_chk_kernel() {
amx_tilecfg cfg = {0};
_tile_loadconfig(&cfg);
_tile_zero(0);
_tile_dpbf16ps(0, 1, 2);
_tile_dpbusd(0, 1, 2);
}
"""

@functools.lru_cache(None) # noqa: B019
def __bool__(self) -> bool:
if super().__bool__():
if config.is_fbcode():
return False
if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
return True
return False


@dataclasses.dataclass
class VecAVX2(VecISA):
_bit_width = 256
Expand Down Expand Up @@ -1483,15 +1526,17 @@ def _check_and_append_supported_isa(

avx2 = torch.cpu._is_cpu_support_avx2()
avx512 = torch.cpu._is_cpu_support_avx512()
amx_tile = torch.cpu._is_cpu_support_amx_tile()

_check_and_append_supported_isa(supported_isa, avx2, "avx2")
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
_check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")

return supported_isa


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]


# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
Expand Down Expand Up @@ -1528,7 +1573,7 @@ def valid_vec_isa_list() -> List[VecISA]:
"""
_cpu_supported_x86_isa = x86_isa_checker()
for isa in supported_vec_isa_list:
if str(isa) in _cpu_supported_x86_isa and isa:
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
isa_list.append(isa)

return isa_list
Expand Down
22 changes: 20 additions & 2 deletions torch/_inductor/codegen/cpp_gemm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import torch
import torch.utils
from ..._dynamo.utils import counters
from .. import ir, lowering as L

from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import V
from .cpp_micro_gemm import create_micro_gemm
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate

from .cpp_template_kernel import CppTemplateKernel
Expand Down Expand Up @@ -84,15 +85,18 @@
int64_t k_block_start = 0;
int64_t k_block_end = K0_blocks;
{%- endif %}
{{ micro_gemm.codegen_init(kernel) }}
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
const int64_t m_start = mc * M0;
const int64_t m_end = std::min((mc + Mc_blocks) * M0, M);
const int64_t m_size = m_end - m_start;
{%- if use_local_acc %}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
{%- endif %}
for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
const int64_t n_start = nc * N0;
const int64_t n_size = N0;
{%- if use_local_acc %}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
Expand Down Expand Up @@ -128,6 +132,7 @@
}}
}
}
{{ micro_gemm.codegen_finalize(kernel) }}
}
}
"""
Expand Down Expand Up @@ -332,6 +337,17 @@ def pack_weight(inputs, layout_or_out):
blocked_w = (
W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous()
)
if micro_gemm.get_b_layout() != LayoutType.NORMAL:
assert (
micro_gemm.get_b_layout() == LayoutType.VNNI2
), "We only support VNNI2 for now"
assert k % 2 == 0, "k should be even for VNNI2 layout"
blocked_w = (
blocked_w.view(n // block_n, k // 2, 2, block_n)
.transpose(-1, -2)
.contiguous()
.view(n // block_n, k, block_n)
)
# normalize stride to be "contiguous_strides" per size
# this avoids the problems in L.view during template codegen
new_stride = [1]
Expand Down Expand Up @@ -462,6 +478,8 @@ def render( # type: ignore[override]
)
assert micro_gemm is not None
assert self.register_blocking == micro_gemm.register_blocking
if isinstance(micro_gemm, CppMicroGemmAMX):
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1

options = dict(
X=X,
Expand Down
Loading

0 comments on commit 914d3ca

Please sign in to comment.