From a2bbced0a17e3823778016653f305a0a1f09dd05 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:16:46 +0800 Subject: [PATCH 1/5] Update README.md --- hgemm/README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/hgemm/README.md b/hgemm/README.md index 06f13ce0..b5b90cd6 100755 --- a/hgemm/README.md +++ b/hgemm/README.md @@ -13,6 +13,24 @@ - [X] hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(bank conflicts reduce, pack) - [X] PyTorch bindings +## 共享内存 Bank Conflicts + +含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict; + +![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png) + +SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以 被一个warp中的所有(32个)线程进行访问,shared_memory 映射到大小相等的32个Bank上,Bank的数据读取带宽为32bit / cycle (4 bytes),因此,主要需要考虑一个Warp内32线程的访问共享内存时的bank冲突。 +对于多个线程读取同一个Bank数据时(不同地址),硬件把内存读写请求,拆分成 conflict-free requests,进行顺序读写,此时将会触发多次内存事务。特别地,当一个warp中的所有线程读写同一个地址时,会触发broadcast机制,此时不会退化成顺序读写。上面提到触发broadcast机制的条件是all threads acess same address,但在翻阅cuda-c-programming-guide以及最新版本的[NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) 时,发现只要是多个thread 读写就会触发broadcast(不需要All)。 + +- 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程 +- 多个线程写同一个数据时,仅会有一个线程写成功(不过这里没有提及是否会将写操作执行多次(即a. 多个线程写入,最后一个线程随机写完; or b. 随机挑选一个线程执行写入),具体流程存疑) + +## 参考文献 + +- [CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274) +- [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md) +- [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440) + ## 测试 ```bash From 5e48158e0892df43a59d70478aeb40a56b78b819 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:38:50 +0800 Subject: [PATCH 2/5] Update README.md --- hgemm/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hgemm/README.md b/hgemm/README.md index b5b90cd6..61b94fbb 100755 --- a/hgemm/README.md +++ b/hgemm/README.md @@ -23,13 +23,16 @@ SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以 对于多个线程读取同一个Bank数据时(不同地址),硬件把内存读写请求,拆分成 conflict-free requests,进行顺序读写,此时将会触发多次内存事务。特别地,当一个warp中的所有线程读写同一个地址时,会触发broadcast机制,此时不会退化成顺序读写。上面提到触发broadcast机制的条件是all threads acess same address,但在翻阅cuda-c-programming-guide以及最新版本的[NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) 时,发现只要是多个thread 读写就会触发broadcast(不需要All)。 - 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程 -- 多个线程写同一个数据时,仅会有一个线程写成功(不过这里没有提及是否会将写操作执行多次(即a. 多个线程写入,最后一个线程随机写完; or b. 随机挑选一个线程执行写入),具体流程存疑) +- 多个线程写同一个数据时,仅会有一个线程写成功 + +[Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更合格式。 ## 参考文献 - [CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274) - [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md) - [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440) +- [Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/) ## 测试 From 3da0c8c364faad1f679210255c710730c7f52583 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:41:12 +0800 Subject: [PATCH 3/5] Update sgemm.cu --- sgemm/sgemm.cu | 70 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/sgemm/sgemm.cu b/sgemm/sgemm.cu index 8b49fcd3..b646169e 100644 --- a/sgemm/sgemm.cu +++ b/sgemm/sgemm.cu @@ -159,7 +159,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_kernel(float* a, float* b, float* c, } } -template +template __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel( float* a, float* b, float* c, const int M, const int N, const int K) { @@ -169,8 +169,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel( const int ty = threadIdx.y; const int tid = ty * blockDim.x + tx; - __shared__ float s_a[BK][BM]; - __shared__ float s_b[BK][BN]; + __shared__ float s_a[BK][BM + OFFSET]; + __shared__ float s_b[BK][BN + OFFSET]; // __shared__ float s_a[BK][BM + 4]; // __shared__ float s_b[BK][BN + 4]; @@ -334,7 +334,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel( } } -template +template __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel( float* a, float* b, float* c, const int M, const int N, const int K) { @@ -344,8 +344,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel( const int ty = threadIdx.y; const int tid = ty * blockDim.x + tx; - __shared__ float s_a[2][BK][BM]; - __shared__ float s_b[2][BK][BN]; + __shared__ float s_a[2][BK][BM + OFFSET]; + __shared__ float s_b[2][BK][BN + OFFSET]; float r_load_a[TM/2]; float r_load_b[TN/2]; @@ -592,6 +592,34 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf(torch::Tensor a, torch::Tensor b, torch::Ten ); } +void sgemm_t_8x8_sliced_k_f32x4_bcf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int BM = 128; + constexpr int BN = 128; + constexpr int BK = 8; + constexpr int TM = 8; + constexpr int TN = 8; + constexpr int OFFSET = 4; + + dim3 block(BN/TN, BM/TM); + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + + sgemm_t_8x8_sliced_k_f32x4_bcf_kernel<<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); +} + void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c) { CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32) CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32) @@ -619,10 +647,40 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch ); } +void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int BM = 128; + constexpr int BN = 128; + constexpr int BK = 8; + constexpr int TM = 8; + constexpr int TN = 8; + constexpr int OFFSET = 4; + + dim3 block(BN/TN, BM/TM); + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + + sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel<<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(sgemm_naive_f32) TORCH_BINDING_COMMON_EXTENSION(sgemm_sliced_k_f32) TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4) TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf) + TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_offset) TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf) + TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset) } From 4470584b20ef86e711efbfc86296abade17fb3bf Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:41:49 +0800 Subject: [PATCH 4/5] Update sgemm.py --- sgemm/sgemm.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/sgemm/sgemm.py b/sgemm/sgemm.py index d1b2d8d2..94e4e29c 100644 --- a/sgemm/sgemm.py +++ b/sgemm/sgemm.py @@ -53,7 +53,7 @@ def run_benchmark(perf_func: callable, out_val = out.flatten().detach().cpu().numpy().tolist()[:3] out_val = [round(v, 8) for v in out_val] out_val = [f"{v:<12}" for v in out_val] - print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms") + print(f"{out_info:>27}: {out_val}, time:{mean_time:.6f}ms") if show_all: print(out) return out.clone(), mean_time @@ -63,15 +63,25 @@ def run_benchmark(perf_func: callable, Ks = [1024, 2048] MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks] for (M, N, K) in MNKs: - print("-" * 85) - print(" " * 35 + f"M={M}, N={N}, K={K}") + print("-" * 100) + print(" " * 45 + f"M={M}, N={N}, K={K}") a = torch.randn((M, K)).cuda().float().contiguous() b = torch.randn((K, N)).cuda().float().contiguous() c = torch.randn((M, N)).cuda().float().contiguous() - run_benchmark(lib.sgemm_naive_f32, a, b, "f32", c) - run_benchmark(lib.sgemm_sliced_k_f32, a, b, "f32(sk)", c) - run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4, a, b, "f32x4(t8x8sk)", c) - run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf, a, b, "f32x4(t8x8bcf)", c) - run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf, a, b, "f32x4(t8x8dbuf)", c) - run_benchmark(partial(torch.matmul, out=c), a, b, "f32_th") - print("-" * 85) + run_benchmark(lib.sgemm_naive_f32, + a, b, "f32", c) + run_benchmark(lib.sgemm_sliced_k_f32, + a, b, "f32(sk)", c) + run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4, + a, b, "f32x4(t8x8sk)", c) + run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf, + a, b, "f32x4(t8x8bcf)", c) + run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_offset, + a, b, "f32x4(t8x8bcf+offset)", c) + run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf, + a, b, "f32x4(t8x8dbuf)", c) + run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset, + a, b, "f32x4(t8x8dbuf+offset)", c) + run_benchmark(partial(torch.matmul, out=c), + a, b, "f32_th") + print("-" * 100) From 9f7d6a112c004dc757deb8d8d6960c53eac6c03f Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:42:40 +0800 Subject: [PATCH 5/5] Update README.md --- sgemm/README.md | 138 +++++++++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 72 deletions(-) diff --git a/sgemm/README.md b/sgemm/README.md index 1620dad9..2f5a38f5 100755 --- a/sgemm/README.md +++ b/sgemm/README.md @@ -22,76 +22,70 @@ python3 sgemm.py 输出: ```bash -------------------------------------------------------------------------------------- - M=2048, N=2048, K=1024 - out_f32: [-23.44512749, 105.22006226, -72.40318298], time:2.581863ms - out_f32(sk): [-23.44512749, 105.22006226, -72.40318298], time:1.837885ms - out_f32x4(t8x8sk): [-23.44512749, 105.22006226, -72.40318298], time:0.325584ms - out_f32x4(t8x8bcf): [-23.44512749, 105.22006226, -72.40318298], time:0.298755ms - out_f32x4(t8x8dbuf): [-23.44512749, 105.22006226, -72.40318298], time:0.229251ms - out_f32_th: [-23.44515038, 105.22006226, -72.40312958], time:0.255888ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=2048, N=2048, K=2048 - out_f32: [4.73375559, -2.49913216, 111.71539307], time:5.155475ms - out_f32(sk): [4.73375559, -2.49913216, 111.71539307], time:3.653073ms - out_f32x4(t8x8sk): [4.73375559, -2.49913216, 111.71539307], time:0.635004ms - out_f32x4(t8x8bcf): [4.73375559, -2.49913216, 111.71539307], time:0.593204ms - out_f32x4(t8x8dbuf): [4.73375559, -2.49913216, 111.71539307], time:0.460200ms - out_f32_th: [4.73375702, -2.49916267, 111.71534729], time:0.467465ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=2048, N=4096, K=1024 - out_f32: [27.58790588, 18.39359474, -23.69882774], time:5.127516ms - out_f32(sk): [27.58790588, 18.39359474, -23.69882774], time:3.652875ms - out_f32x4(t8x8sk): [27.58790588, 18.39359474, -23.69882774], time:0.626333ms - out_f32x4(t8x8bcf): [27.58790588, 18.39359474, -23.69882774], time:0.549185ms - out_f32x4(t8x8dbuf): [27.58790588, 18.39359474, -23.69882774], time:0.463538ms - out_f32_th: [27.58790588, 18.39359474, -23.69882774], time:0.555634ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=2048, N=4096, K=2048 - out_f32: [54.19274139, -0.29313943, 26.92167664], time:10.221355ms - out_f32(sk): [54.19274139, -0.29313943, 26.92167664], time:7.268925ms - out_f32x4(t8x8sk): [54.19274139, -0.29313943, 26.92167664], time:1.249781ms - out_f32x4(t8x8bcf): [54.19274139, -0.29313943, 26.92167664], time:1.119103ms - out_f32x4(t8x8dbuf): [54.19274139, -0.29313943, 26.92167664], time:0.960808ms - out_f32_th: [54.19275284, -0.29314613, 26.92167473], time:0.920537ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=4096, N=2048, K=1024 - out_f32: [-37.67934418, 12.49935532, 40.71273804], time:5.120614ms - out_f32(sk): [-37.67934418, 12.49935532, 40.71273804], time:3.652627ms - out_f32x4(t8x8sk): [-37.67934418, 12.49935532, 40.71273804], time:0.624588ms - out_f32x4(t8x8bcf): [-37.67934418, 12.49935532, 40.71273804], time:0.545461ms - out_f32x4(t8x8dbuf): [-37.67934418, 12.49935532, 40.71273804], time:0.462778ms - out_f32_th: [-37.67934418, 12.49935532, 40.71273804], time:0.560777ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=4096, N=2048, K=2048 - out_f32: [-15.01755524, -0.44903478, 72.23948669], time:10.213506ms - out_f32(sk): [-15.01755524, -0.44903478, 72.23948669], time:7.269592ms - out_f32x4(t8x8sk): [-15.01755524, -0.44903478, 72.23948669], time:1.242898ms - out_f32x4(t8x8bcf): [-15.01755524, -0.44903478, 72.23948669], time:1.099443ms - out_f32x4(t8x8dbuf): [-15.01755524, -0.44903478, 72.23948669], time:0.941424ms - out_f32_th: [-15.01752663, -0.44904327, 72.23952484], time:0.940223ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=4096, N=4096, K=1024 - out_f32: [-5.76778412, 22.12718964, 17.76623344], time:10.221822ms - out_f32(sk): [-5.76778412, 22.12718964, 17.76623344], time:7.308133ms - out_f32x4(t8x8sk): [-5.76778412, 22.12718964, 17.76623344], time:1.263077ms - out_f32x4(t8x8bcf): [-5.76778412, 22.12718964, 17.76623344], time:1.134577ms - out_f32x4(t8x8dbuf): [-5.76778412, 22.12718964, 17.76623344], time:1.009488ms - out_f32_th: [-5.76778412, 22.12718964, 17.76623344], time:0.926571ms -------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------- - M=4096, N=4096, K=2048 - out_f32: [35.152565, 56.02351761, 29.87486458], time:20.362103ms - out_f32(sk): [35.152565, 56.02351761, 29.87486458], time:14.596984ms - out_f32x4(t8x8sk): [35.152565, 56.02351761, 29.87486458], time:2.558391ms - out_f32x4(t8x8bcf): [35.152565, 56.02351761, 29.87486458], time:2.313538ms - out_f32x4(t8x8dbuf): [35.152565, 56.02351761, 29.87486458], time:2.144170ms - out_f32_th: [35.152565, 56.02351761, 29.87486458], time:1.896987ms -------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=2048, N=2048, K=1024 + out_f32: ['-41.69404602', '-15.22974205', '12.31010342 '], time:2.583222ms + out_f32(sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:1.836123ms + out_f32x4(t8x8sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.324936ms + out_f32x4(t8x8bcf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.290537ms + out_f32x4(t8x8bcf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.289106ms + out_f32x4(t8x8dbuf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.229044ms + out_f32x4(t8x8dbuf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.230970ms + out_f32_th: ['-41.69403076', '-15.229743 ', '12.31009007 '], time:0.255721ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=2048, N=2048, K=2048 + out_f32: ['-11.50634861', '-30.57016182', '14.03067684 '], time:5.152175ms + out_f32(sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:3.652353ms + out_f32x4(t8x8sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.639246ms + out_f32x4(t8x8bcf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.576742ms + out_f32x4(t8x8bcf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.575581ms + out_f32x4(t8x8dbuf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.460470ms + out_f32x4(t8x8dbuf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.465369ms + out_f32_th: ['-11.50632 ', '-30.57013321', '14.03067398 '], time:0.465064ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=2048, N=4096, K=1024 + out_f32: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:5.122924ms + out_f32(sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:3.653028ms + out_f32x4(t8x8sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.625312ms + out_f32x4(t8x8bcf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.534370ms + out_f32x4(t8x8bcf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.530348ms + out_f32x4(t8x8dbuf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.462132ms + out_f32x4(t8x8dbuf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.464492ms + out_f32_th: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.557373ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=2048, N=4096, K=2048 + out_f32: ['61.41757584 ', '107.04826355', '37.28448868 '], time:10.218813ms + out_f32(sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:7.268655ms + out_f32x4(t8x8sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.237755ms + out_f32x4(t8x8bcf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.065564ms + out_f32x4(t8x8bcf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.053824ms + out_f32x4(t8x8dbuf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.935848ms + out_f32x4(t8x8dbuf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.967648ms + out_f32_th: ['61.41755676 ', '107.04829407', '37.28450775 '], time:0.921094ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=4096, N=2048, K=1024 + out_f32: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:5.120900ms + out_f32(sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:3.651984ms + out_f32x4(t8x8sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.622756ms + out_f32x4(t8x8bcf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.526509ms + out_f32x4(t8x8bcf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.529506ms + out_f32x4(t8x8dbuf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.451362ms + out_f32x4(t8x8dbuf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.462964ms + out_f32_th: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.552487ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + M=4096, N=2048, K=2048 + out_f32: ['62.51137161 ', '-45.17026138', '61.54212952 '], time:10.213661ms + out_f32(sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:7.267971ms + out_f32x4(t8x8sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.244769ms + out_f32x4(t8x8bcf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.076307ms + out_f32x4(t8x8bcf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.074743ms + out_f32x4(t8x8dbuf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.948534ms + out_f32x4(t8x8dbuf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.963700ms + out_f32_th: ['62.51136398 ', '-45.17026138', '61.54217911 '], time:0.916274ms +---------------------------------------------------------------------------------------------------- ```