Skip to content

Commit

Permalink
K10: Thread Swizzling
Browse files Browse the repository at this point in the history
  • Loading branch information
siboehm committed Feb 25, 2023
1 parent 1dfc00d commit 9a4dd77
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/kernels/10_kernel_warptiling.cuh
Expand Up @@ -26,8 +26,16 @@ template <const int BM, const int BN, const int BK, const int WM, const int WN,
__global__ void __launch_bounds__(NUM_THREADS)
sgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
float beta, float *C) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
const uint SWIZZLE = 4;
// which swizzle block are we in
const uint swizzleBlockIdx = blockIdx.x / (SWIZZLE * SWIZZLE);
// index inside the swizzle block
const uint swizzleIdx = blockIdx.x % (SWIZZLE * SWIZZLE);

const uint cCol =
(swizzleBlockIdx % (N / BN / SWIZZLE)) * SWIZZLE + (swizzleIdx % SWIZZLE);
const uint cRow =
(swizzleBlockIdx / (N / BN / SWIZZLE)) * SWIZZLE + (swizzleIdx / SWIZZLE);

// Placement of the warp in the threadblock tile
const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in
Expand Down

0 comments on commit 9a4dd77

Please sign in to comment.