Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions hgemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@
- [X] hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(bank conflicts reduce, pack)
- [X] PyTorch bindings

## 共享内存 Bank Conflicts

含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict;

![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png)

SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以 被一个warp中的所有(32个)线程进行访问,shared_memory 映射到大小相等的32个Bank上,Bank的数据读取带宽为32bit / cycle (4 bytes),因此,主要需要考虑一个Warp内32线程的访问共享内存时的bank冲突。
对于多个线程读取同一个Bank数据时(不同地址),硬件把内存读写请求,拆分成 conflict-free requests,进行顺序读写,此时将会触发多次内存事务。特别地,当一个warp中的所有线程读写同一个地址时,会触发broadcast机制,此时不会退化成顺序读写。上面提到触发broadcast机制的条件是all threads acess same address,但在翻阅cuda-c-programming-guide以及最新版本的[NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) 时,发现只要是多个thread 读写就会触发broadcast(不需要All)。

- 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程
- 多个线程写同一个数据时,仅会有一个线程写成功

[Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更合格式。

## 参考文献

- [CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274)
- [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md)
- [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440)
- [Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)

## 测试

```bash
Expand Down
138 changes: 66 additions & 72 deletions sgemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,76 +22,70 @@ python3 sgemm.py
输出:

```bash
-------------------------------------------------------------------------------------
M=2048, N=2048, K=1024
out_f32: [-23.44512749, 105.22006226, -72.40318298], time:2.581863ms
out_f32(sk): [-23.44512749, 105.22006226, -72.40318298], time:1.837885ms
out_f32x4(t8x8sk): [-23.44512749, 105.22006226, -72.40318298], time:0.325584ms
out_f32x4(t8x8bcf): [-23.44512749, 105.22006226, -72.40318298], time:0.298755ms
out_f32x4(t8x8dbuf): [-23.44512749, 105.22006226, -72.40318298], time:0.229251ms
out_f32_th: [-23.44515038, 105.22006226, -72.40312958], time:0.255888ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=2048, N=2048, K=2048
out_f32: [4.73375559, -2.49913216, 111.71539307], time:5.155475ms
out_f32(sk): [4.73375559, -2.49913216, 111.71539307], time:3.653073ms
out_f32x4(t8x8sk): [4.73375559, -2.49913216, 111.71539307], time:0.635004ms
out_f32x4(t8x8bcf): [4.73375559, -2.49913216, 111.71539307], time:0.593204ms
out_f32x4(t8x8dbuf): [4.73375559, -2.49913216, 111.71539307], time:0.460200ms
out_f32_th: [4.73375702, -2.49916267, 111.71534729], time:0.467465ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=2048, N=4096, K=1024
out_f32: [27.58790588, 18.39359474, -23.69882774], time:5.127516ms
out_f32(sk): [27.58790588, 18.39359474, -23.69882774], time:3.652875ms
out_f32x4(t8x8sk): [27.58790588, 18.39359474, -23.69882774], time:0.626333ms
out_f32x4(t8x8bcf): [27.58790588, 18.39359474, -23.69882774], time:0.549185ms
out_f32x4(t8x8dbuf): [27.58790588, 18.39359474, -23.69882774], time:0.463538ms
out_f32_th: [27.58790588, 18.39359474, -23.69882774], time:0.555634ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=2048, N=4096, K=2048
out_f32: [54.19274139, -0.29313943, 26.92167664], time:10.221355ms
out_f32(sk): [54.19274139, -0.29313943, 26.92167664], time:7.268925ms
out_f32x4(t8x8sk): [54.19274139, -0.29313943, 26.92167664], time:1.249781ms
out_f32x4(t8x8bcf): [54.19274139, -0.29313943, 26.92167664], time:1.119103ms
out_f32x4(t8x8dbuf): [54.19274139, -0.29313943, 26.92167664], time:0.960808ms
out_f32_th: [54.19275284, -0.29314613, 26.92167473], time:0.920537ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=4096, N=2048, K=1024
out_f32: [-37.67934418, 12.49935532, 40.71273804], time:5.120614ms
out_f32(sk): [-37.67934418, 12.49935532, 40.71273804], time:3.652627ms
out_f32x4(t8x8sk): [-37.67934418, 12.49935532, 40.71273804], time:0.624588ms
out_f32x4(t8x8bcf): [-37.67934418, 12.49935532, 40.71273804], time:0.545461ms
out_f32x4(t8x8dbuf): [-37.67934418, 12.49935532, 40.71273804], time:0.462778ms
out_f32_th: [-37.67934418, 12.49935532, 40.71273804], time:0.560777ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=4096, N=2048, K=2048
out_f32: [-15.01755524, -0.44903478, 72.23948669], time:10.213506ms
out_f32(sk): [-15.01755524, -0.44903478, 72.23948669], time:7.269592ms
out_f32x4(t8x8sk): [-15.01755524, -0.44903478, 72.23948669], time:1.242898ms
out_f32x4(t8x8bcf): [-15.01755524, -0.44903478, 72.23948669], time:1.099443ms
out_f32x4(t8x8dbuf): [-15.01755524, -0.44903478, 72.23948669], time:0.941424ms
out_f32_th: [-15.01752663, -0.44904327, 72.23952484], time:0.940223ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=4096, N=4096, K=1024
out_f32: [-5.76778412, 22.12718964, 17.76623344], time:10.221822ms
out_f32(sk): [-5.76778412, 22.12718964, 17.76623344], time:7.308133ms
out_f32x4(t8x8sk): [-5.76778412, 22.12718964, 17.76623344], time:1.263077ms
out_f32x4(t8x8bcf): [-5.76778412, 22.12718964, 17.76623344], time:1.134577ms
out_f32x4(t8x8dbuf): [-5.76778412, 22.12718964, 17.76623344], time:1.009488ms
out_f32_th: [-5.76778412, 22.12718964, 17.76623344], time:0.926571ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
M=4096, N=4096, K=2048
out_f32: [35.152565, 56.02351761, 29.87486458], time:20.362103ms
out_f32(sk): [35.152565, 56.02351761, 29.87486458], time:14.596984ms
out_f32x4(t8x8sk): [35.152565, 56.02351761, 29.87486458], time:2.558391ms
out_f32x4(t8x8bcf): [35.152565, 56.02351761, 29.87486458], time:2.313538ms
out_f32x4(t8x8dbuf): [35.152565, 56.02351761, 29.87486458], time:2.144170ms
out_f32_th: [35.152565, 56.02351761, 29.87486458], time:1.896987ms
-------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=2048, N=2048, K=1024
out_f32: ['-41.69404602', '-15.22974205', '12.31010342 '], time:2.583222ms
out_f32(sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:1.836123ms
out_f32x4(t8x8sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.324936ms
out_f32x4(t8x8bcf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.290537ms
out_f32x4(t8x8bcf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.289106ms
out_f32x4(t8x8dbuf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.229044ms
out_f32x4(t8x8dbuf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.230970ms
out_f32_th: ['-41.69403076', '-15.229743 ', '12.31009007 '], time:0.255721ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=2048, N=2048, K=2048
out_f32: ['-11.50634861', '-30.57016182', '14.03067684 '], time:5.152175ms
out_f32(sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:3.652353ms
out_f32x4(t8x8sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.639246ms
out_f32x4(t8x8bcf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.576742ms
out_f32x4(t8x8bcf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.575581ms
out_f32x4(t8x8dbuf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.460470ms
out_f32x4(t8x8dbuf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.465369ms
out_f32_th: ['-11.50632 ', '-30.57013321', '14.03067398 '], time:0.465064ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=2048, N=4096, K=1024
out_f32: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:5.122924ms
out_f32(sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:3.653028ms
out_f32x4(t8x8sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.625312ms
out_f32x4(t8x8bcf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.534370ms
out_f32x4(t8x8bcf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.530348ms
out_f32x4(t8x8dbuf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.462132ms
out_f32x4(t8x8dbuf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.464492ms
out_f32_th: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.557373ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=2048, N=4096, K=2048
out_f32: ['61.41757584 ', '107.04826355', '37.28448868 '], time:10.218813ms
out_f32(sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:7.268655ms
out_f32x4(t8x8sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.237755ms
out_f32x4(t8x8bcf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.065564ms
out_f32x4(t8x8bcf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.053824ms
out_f32x4(t8x8dbuf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.935848ms
out_f32x4(t8x8dbuf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.967648ms
out_f32_th: ['61.41755676 ', '107.04829407', '37.28450775 '], time:0.921094ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=4096, N=2048, K=1024
out_f32: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:5.120900ms
out_f32(sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:3.651984ms
out_f32x4(t8x8sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.622756ms
out_f32x4(t8x8bcf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.526509ms
out_f32x4(t8x8bcf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.529506ms
out_f32x4(t8x8dbuf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.451362ms
out_f32x4(t8x8dbuf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.462964ms
out_f32_th: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.552487ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
M=4096, N=2048, K=2048
out_f32: ['62.51137161 ', '-45.17026138', '61.54212952 '], time:10.213661ms
out_f32(sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:7.267971ms
out_f32x4(t8x8sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.244769ms
out_f32x4(t8x8bcf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.076307ms
out_f32x4(t8x8bcf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.074743ms
out_f32x4(t8x8dbuf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.948534ms
out_f32x4(t8x8dbuf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.963700ms
out_f32_th: ['62.51136398 ', '-45.17026138', '61.54217911 '], time:0.916274ms
----------------------------------------------------------------------------------------------------
```
70 changes: 64 additions & 6 deletions sgemm/sgemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_kernel(float* a, float* b, float* c,
}
}

template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8, const int OFFSET=0>
__global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
float* a, float* b, float* c, const int M, const int N, const int K) {

Expand All @@ -169,8 +169,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;

__shared__ float s_a[BK][BM];
__shared__ float s_b[BK][BN];
__shared__ float s_a[BK][BM + OFFSET];
__shared__ float s_b[BK][BN + OFFSET];
// __shared__ float s_a[BK][BM + 4];
// __shared__ float s_b[BK][BN + 4];

Expand Down Expand Up @@ -334,7 +334,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
}
}

template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8, const int OFFSET=0>
__global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
float* a, float* b, float* c, const int M, const int N, const int K) {

Expand All @@ -344,8 +344,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;

__shared__ float s_a[2][BK][BM];
__shared__ float s_b[2][BK][BN];
__shared__ float s_a[2][BK][BM + OFFSET];
__shared__ float s_b[2][BK][BN + OFFSET];

float r_load_a[TM/2];
float r_load_b[TN/2];
Expand Down Expand Up @@ -592,6 +592,34 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf(torch::Tensor a, torch::Tensor b, torch::Ten
);
}

void sgemm_t_8x8_sliced_k_f32x4_bcf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32)
const int M = a.size(0);
const int K = a.size(1);
const int N = b.size(1);
CHECK_TORCH_TENSOR_SHAPE(a, M, K)
CHECK_TORCH_TENSOR_SHAPE(b, K, N)
CHECK_TORCH_TENSOR_SHAPE(c, M, N)
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 8;
constexpr int TM = 8;
constexpr int TN = 8;
constexpr int OFFSET = 4;

dim3 block(BN/TN, BM/TM);
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

sgemm_t_8x8_sliced_k_f32x4_bcf_kernel<BM, BN, BK, TM, TN, OFFSET><<<grid, block>>>(
reinterpret_cast<float*>(a.data_ptr()),
reinterpret_cast<float*>(b.data_ptr()),
reinterpret_cast<float*>(c.data_ptr()),
M, N, K
);
}

void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
Expand Down Expand Up @@ -619,10 +647,40 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch
);
}

void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32)
const int M = a.size(0);
const int K = a.size(1);
const int N = b.size(1);
CHECK_TORCH_TENSOR_SHAPE(a, M, K)
CHECK_TORCH_TENSOR_SHAPE(b, K, N)
CHECK_TORCH_TENSOR_SHAPE(c, M, N)
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 8;
constexpr int TM = 8;
constexpr int TN = 8;
constexpr int OFFSET = 4;

dim3 block(BN/TN, BM/TM);
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel<BM, BN, BK, TM, TN, OFFSET><<<grid, block>>>(
reinterpret_cast<float*>(a.data_ptr()),
reinterpret_cast<float*>(b.data_ptr()),
reinterpret_cast<float*>(c.data_ptr()),
M, N, K
);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(sgemm_naive_f32)
TORCH_BINDING_COMMON_EXTENSION(sgemm_sliced_k_f32)
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4)
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf)
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_offset)
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf)
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset)
}
Loading