
Commit

Update on "[Inductor][CPP] Pass weight dtype explicitly for cpp gemm template"


**Summary**
This PR mainly refactors two things:

1. Pass the weight's data type explicitly to `create_micro_gemm` as `input2.dtype`. When registering a `CppMicroGemmConfig`, `input.dtype` is reused if `input2.dtype` is not explicitly registered.
2. Add a utility function that derives the output data type and compute data type from the input data type.
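The utility function described in point 2 can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual implementation; the function name and the exact dtype mapping are assumptions based on the registered configs shown in the diff (int8/uint8 GEMMs accumulating in int32, float GEMMs in fp32):

```python
import torch

def get_output_and_compute_dtype(input_dtype):
    """Hypothetical sketch: map a GEMM input dtype to (output_dtype, compute_dtype)."""
    # Low-precision integer inputs accumulate in int32 (as in the AMX u8s8 config
    # below); float inputs (fp32/bf16/fp16) compute and accumulate in fp32.
    if input_dtype in (torch.uint8, torch.int8):
        return torch.int32, torch.int32
    return torch.float32, torch.float32
```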



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel committed Jun 21, 2024
1 parent d360d87 commit f0e37c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp_micro_gemm.py
@@ -189,10 +189,10 @@ def generate_gemm_config(
     vec_isa_cls,
     register_blockings,
     input_dtype=torch.float,
+    input2_dtype=None,
     output_dtype=None,
     compute_dtype=None,
     extra_check=None,
-    input2_dtype=None,
 ):
     if output_dtype is None:
         output_dtype = input_dtype
@@ -460,10 +460,10 @@ def check_amx_extra(config, m, n, k, alpha, num_threads):
         VecAMX,
         [(32, 32, 64), (48, 16, 64)],
         input_dtype=torch.uint8,
+        input2_dtype=torch.int8,
         output_dtype=torch.int32,
         compute_dtype=torch.int32,
         extra_check=check_amx_extra,
-        input2_dtype=torch.int8,
     ),
 )
 class CppMicroGemmAMX(CppMicroGemm):
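The `input2.dtype` fallback described in the summary can be sketched as below. This is a simplified, hypothetical illustration of the registration-time defaulting behavior, not the actual `generate_gemm_config` implementation; the function name and dict return shape are assumptions:

```python
import torch

def resolve_gemm_dtypes(input_dtype, input2_dtype=None,
                        output_dtype=None, compute_dtype=None):
    # Hypothetical sketch of the defaulting described in the summary:
    # when input2_dtype is not explicitly registered, reuse input.dtype.
    if input2_dtype is None:
        input2_dtype = input_dtype
    if output_dtype is None:
        output_dtype = input_dtype
    if compute_dtype is None:
        compute_dtype = output_dtype
    return {
        "input_dtype": input_dtype,
        "input2_dtype": input2_dtype,
        "output_dtype": output_dtype,
        "compute_dtype": compute_dtype,
    }
```

For the AMX int8 config above, all four dtypes are registered explicitly, so no fallback applies; for float configs only `input_dtype` is given and the rest default from it.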
