# Kernel 4: 1D Blocktiling

This next improvement is very similar to the last kernel, but now we do something quite unheard of up until now - threads will calculate more than a single entry in the output matrix.

Yes, you heard me right - a groundbreaking revolution it seems!

Let's first look at the overall process, to start from a familiar place.

![](../../images/GEMM1/blocktilingouter.png)

This is the main outer loop, and as you can see, it looks pretty similar to our previous kernel. We have a chunk of C being computed using sliding windows of chunks from A and B, with each intermediate chunk (of A and B) being loaded into shared memory. Also, as before, each output chunk corresponds to a single block.

Note that BK, the width of the chunk of A (and the height of the chunk of B), is set to 8.

The key difference, as mentioned, is that each thread will now compute a small column of entries of C, instead of a single entry of C.

This can be illustrated by the following:

![](../../images/GEMM1/blocktilinginner.png)

This is the inner loop, where the partial dot products using each chunk (of both A and B) are computed. The picture is a little misleading, since it draws a warp as ~4 threads (a warp is still 32 threads), but it shows the key idea: each thread computes a small column, TM elements to be exact. This means that each warp calculates a 2D section of the chunk.
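To make this mapping concrete, here is a small host-side C++ sketch (plain C++, not CUDA) that mirrors the kernel's thread indexing. It assumes BM = BN = 64 and TM = 8, and the helper names (`ownedSlice`, `tileCoveredExactlyOnce`) are hypothetical. It only checks that the BM * BN / TM = 512 threads of a block cover the 64 × 64 output chunk with no gaps and no overlaps.

```cpp
#include <cassert>
#include <vector>

// Assumed tile sizes: a BM x BN output chunk, TM results per thread.
constexpr int BM = 64, BN = 64, TM = 8;
constexpr int kThreadsPerBlock = BM * BN / TM;  // 512 threads per block

// The slice of the BM x BN output chunk owned by one thread:
// rows firstRow .. firstRow + TM - 1 of column col.
struct OwnedSlice {
    int firstRow;
    int col;
};

// Mirrors the kernel: threadCol = threadIdx.x % BN, threadRow = threadIdx.x / BN.
OwnedSlice ownedSlice(int tid) {
    assert(tid >= 0 && tid < kThreadsPerBlock);
    return {(tid / BN) * TM, tid % BN};
}

// True if every entry of the output chunk is owned by exactly one thread.
bool tileCoveredExactlyOnce() {
    std::vector<int> owners(BM * BN, 0);
    for (int tid = 0; tid < kThreadsPerBlock; ++tid) {
        const OwnedSlice s = ownedSlice(tid);
        for (int resIdx = 0; resIdx < TM; ++resIdx)
            ++owners[(s.firstRow + resIdx) * BN + s.col];
    }
    for (int count : owners)
        if (count != 1) return false;
    return true;
}
```

With this indexing, threads 0 to 31 all get threadRow = 0 and columns 0 to 31, so one warp covers an 8 × 32 patch of the output chunk: the 2D section mentioned above.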

This is just an overview, so let's stick with our usual practice and dive into the code step by step to really shine a light on this optimization.

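```cpp
// advance pointers to the starting positions of this block's tiles
A += cRow * BM * K;             // row=cRow, col=0
B += cCol * BN;                 // row=0, col=cCol
C += cRow * BM * N + cCol * BN; // row=cRow, col=cCol

// the SMEM tiles are no longer square: As is BM x BK, Bs is BK x BN
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];

// this thread owns column threadCol, rows threadRow * TM .. threadRow * TM + TM - 1
const uint threadCol = threadIdx.x % BN;
const uint threadRow = threadIdx.x / BN;

// thread-local cache for this thread's TM results
float threadResults[TM] = {0.0};
```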

First, as before, we advance all the pointers to their starting positions so we can use local thread indexing, and then, in the outer loop, keep moving the pointers further along the row (in A) and column (in B) in a sliding-window fashion.

We also allocate the two shared memory arrays.

Then, we allocate a thread-local cache, so that each individual thread can store its TM results. It initializes all the elements to 0.0 (duh, I know).
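Here is a minimal sketch of the sliding-window arithmetic, written as plain C++ on element offsets rather than real pointers. It assumes row-major A (M × K), B (K × N) and C (M × N), with BM = BN = 64. The helpers `aTileOffset`, `bTileOffset` and `cTileOffset` are hypothetical: they combine the initial pointer advance with the `A += BK` and `B += BK * N` steps taken once per outer-loop iteration.

```cpp
#include <cassert>
#include <cstddef>

// Assumed tile sizes for this sketch.
constexpr std::size_t BM = 64, BN = 64;

// Element offset of the As tile's top-left corner in row-major A (M x K):
// the initial advance (row cRow * BM) plus bkIdx columns from the A += BK steps.
std::size_t aTileOffset(std::size_t cRow, std::size_t bkIdx, std::size_t K) {
    assert(bkIdx < K);
    return cRow * BM * K + bkIdx;
}

// Element offset of the Bs tile's top-left corner in row-major B (K x N):
// the initial advance (column cCol * BN) plus bkIdx rows from the B += BK * N steps.
std::size_t bTileOffset(std::size_t cCol, std::size_t bkIdx, std::size_t N) {
    return cCol * BN + bkIdx * N;
}

// Element offset of this block's BM x BN tile in row-major C (M x N); it never moves.
std::size_t cTileOffset(std::size_t cRow, std::size_t cCol, std::size_t N) {
    return cRow * BM * N + cCol * BN;
}
```

For example, with K = N = 4096, block (cRow = 1, cCol = 2) at bkIdx = 8 starts its As tile at offset 262152 and its Bs tile at offset 32896: moving along K shifts the A window by columns, but shifts the B window by whole rows of N elements.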


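```cpp
// which element of each SMEM tile this thread loads (one of A, one of B)
const uint innerColA = threadIdx.x % BK; // As is BM x BK
const uint innerRowA = threadIdx.x / BK;
const uint innerColB = threadIdx.x % BN; // Bs is BK x BN
const uint innerRowB = threadIdx.x / BN;

// outer loop: slide along K one BK-wide chunk at a time
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
    // populate the SMEM caches: one element of A and one of B per thread
    As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
    Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB];
    __syncthreads();
    // ... rest of outer loop body
}
```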

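This is the main outer loop that iterates over the chunks of A and B. Populating the SMEM caches works much like before, BUT something is a bit different: the tiles are no longer square. As holds a BM × BK chunk of A and Bs holds a BK × BN chunk of B, so the 1D thread index is remapped onto each tile's shape (innerRowA/innerColA for As, innerRowB/innerColB for Bs).

With the sizes used here (BM = BN = 64, BK = 8, TM = 8), a block has BM * BN / TM = 512 threads, and each tile has exactly 64 * 8 = 512 elements. So each thread still loads just one element of A and one element of B per outer-loop iteration. What changes is how much each thread *reads* from the SMEM caches afterwards: since it calculates a whole column of TM results, it needs TM elements of As for every element of Bs it uses.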

Then, as before, we do a __syncthreads() call to ensure that all the necessary chunks have been loaded before we start computing partial dot products.
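Here is a host-side sketch of the load mapping (plain C++, assuming BM = BN = 64, BK = 8 and TM = 8; `eachTileLoadedExactlyOnce` is a hypothetical helper). It replays one outer-loop iteration's SMEM stores for all 512 threads and confirms that one load of A plus one load of B per thread fills both tiles exactly once.

```cpp
#include <cassert>
#include <vector>

// Assumed tile sizes; BM * BN / TM = 512 threads per block.
constexpr int BM = 64, BN = 64, BK = 8, TM = 8;
constexpr int kThreads = BM * BN / TM;
static_assert(BM * BK == kThreads && BK * BN == kThreads,
              "one element of A and one of B per thread per outer iteration");

// Replays one outer-loop iteration's SMEM stores for every thread and checks
// that each slot of As (BM x BK) and Bs (BK x BN) is written exactly once.
bool eachTileLoadedExactlyOnce() {
    std::vector<int> hitsA(BM * BK, 0), hitsB(BK * BN, 0);
    for (int tid = 0; tid < kThreads; ++tid) {
        const int innerColA = tid % BK, innerRowA = tid / BK;
        const int innerColB = tid % BN, innerRowB = tid / BN;
        assert(innerRowA < BM && innerRowB < BK);
        ++hitsA[innerRowA * BK + innerColA];
        ++hitsB[innerRowB * BN + innerColB];
    }
    for (int h : hitsA) if (h != 1) return false;
    for (int h : hitsB) if (h != 1) return false;
    return true;
}
```

Note that innerRowA * BK + innerColA simplifies back to the thread index, so the remapping matters mostly on the global-memory side, where innerRowA * K + innerColA picks the right row and column of the full A matrix.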

Taking a further look into the outer loop, where we start computing partial dot products after loading in the entire chunk(s):


```cpp
// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
    // ... first inner loop
}
```

This first inner loop runs BK times, which is the number of columns of As and the number of rows of Bs. It's been a while since I pulled out a .gif, but here's one to quickly illustrate what this first inner loop is doing.

This was also the moment I realized that the super cool images Simon drew were not his personal hand-written notes but were made using this wonderful site, [Excalidraw](https://excalidraw.com/), so I thought I would use it for my gifs too, for consistency.

![](../../images/GEMM1/blocktilingbasic.gif)

So it simply just moves along the columns of As and rows of Bs to compute the partial dot product for this outer loop (like the gif said).

What we notice is that each iteration involves multiplying a TM-long slice of a column of As (one entry from each of the rows our thread owns) by a single element of Bs, to get a partial-partial (not a typo) dot product for each of the TM entries that our thread is responsible for computing.

E.g. if our thread needs to find 4 elements of the output matrix (this would be our TM), it needs 4 complete rows from As and 1 column from Bs, since it computes a single column of output entries. We could do this row by row, i.e. multiply the row for the first entry by the column in Bs, then multiply the row for the second entry by the column in Bs, and so on. But come on now, something so naive probably seems silly to you at this point - we are optimization-hungry!

So instead we take the first column of the 4 corresponding rows of As, and the first element of the corresponding column of Bs, and use these to find the first part of the partial dot products for all 4 entries. Then we take the second column of those 4 rows of As and the second element of the column of Bs, and so on, until we have the partial dot products for all 4 entries that our thread is responsible for. At that point, its work for the current chunk is done.

This can be seen by looking into the inner loop, and guess what we find?


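```cpp
// inside the dotIdx loop: cache one Bs element in a register...
float Btmp = Bs[dotIdx * BN + threadCol];
// ...and reuse it for all TM results this thread owns
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
    threadResults[resIdx] +=
        As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
}
```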

So each iteration of the first inner loop is exactly what I was describing: a column slice of As (TM long) multiplied by a single element of Bs. That's why we cache this element of Bs in Btmp, so that we can reuse it for every element in the column of As instead of reading it from SMEM TM times.

Then we simply do the multiplications and accumulate each product into the corresponding resIdx slot of the thread-local array we created earlier, where, across all the chunks, the products add up to the full dot products.
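As a sanity check on the loop order, here is a host-side C++ sketch of one thread's work for one chunk, written both in the kernel's order (dotIdx outer, Btmp reused) and in the row-by-row order described earlier. It assumes BM = BN = 64, BK = TM = 8 and flattened row-major tiles; `chunkDotIdxOuter` and `chunkRowByRow` are hypothetical names.

```cpp
#include <array>
#include <cassert>

// Assumed sizes: As is BM x BK and Bs is BK x BN, both flattened row-major.
constexpr int BM = 64, BN = 64, BK = 8, TM = 8;
using Results = std::array<float, TM>;

// Kernel's order: for each dotIdx, read one Bs value (Btmp) and reuse it
// for all TM partial sums this thread owns.
Results chunkDotIdxOuter(const float* As, const float* Bs, int threadRow, int threadCol) {
    assert(threadRow >= 0 && threadRow < BM / TM && threadCol >= 0 && threadCol < BN);
    Results results{};  // zero-initialized, like threadResults
    for (int dotIdx = 0; dotIdx < BK; ++dotIdx) {
        const float Btmp = Bs[dotIdx * BN + threadCol];
        for (int resIdx = 0; resIdx < TM; ++resIdx)
            results[resIdx] += As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
    }
    return results;
}

// Row-by-row order: finish one result's dot product before starting the next,
// re-reading Bs for every result.
Results chunkRowByRow(const float* As, const float* Bs, int threadRow, int threadCol) {
    assert(threadRow >= 0 && threadRow < BM / TM && threadCol >= 0 && threadCol < BN);
    Results results{};
    for (int resIdx = 0; resIdx < TM; ++resIdx)
        for (int dotIdx = 0; dotIdx < BK; ++dotIdx)
            results[resIdx] += As[(threadRow * TM + resIdx) * BK + dotIdx] *
                               Bs[dotIdx * BN + threadCol];
    return results;
}
```

Both versions add the products for each result in the same dotIdx order, so they agree exactly; the difference is how many times Bs is read.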

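```cpp
__syncthreads(); // wait until every thread is done reading As and Bs

A += BK;     // move the A window BK columns to the right
B += BK * N; // move the B window BK rows down
```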

Still inside the outermost loop, we do the __syncthreads() call as before to ensure that no warps rush ahead to the next outer loop iteration and change As and Bs. We also move the pointers for A and B to move onto the next chunk.
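Putting the pieces together, here is a host-side emulation of the whole kernel in plain C++, checked against a naive triple loop. It is a sketch under simplifying assumptions: BM = BN = 64, BK = TM = 8, M, N and K are multiples of the tile sizes, and it computes C = A * B (no alpha/beta scaling). The loops over `tid` stand in for the threads of a block, and running every load before any compute (and all compute before the next load) is the ordering the two `__syncthreads()` calls enforce on the GPU. The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Assumed tile sizes; BM * BN / TM = 512 "threads" per block.
constexpr int BM = 64, BN = 64, BK = 8, TM = 8;
constexpr int kThreads = BM * BN / TM;

// Host-side emulation of the 1D blocktiling kernel computing C = A * B
// (row-major; no alpha/beta). Requires M, N, K to be multiples of BM, BN, BK.
void sgemm1DBlocktilingHost(int M, int N, int K, const float* A, const float* B, float* C) {
    assert(M % BM == 0 && N % BN == 0 && K % BK == 0);
    std::vector<float> As(BM * BK), Bs(BK * BN);      // the "shared memory" tiles
    std::vector<float> threadResults(kThreads * TM);  // TM registers per thread
    for (int cRow = 0; cRow < M / BM; ++cRow) {
        for (int cCol = 0; cCol < N / BN; ++cCol) {
            const float* a = A + cRow * BM * K;  // row=cRow, col=0
            const float* b = B + cCol * BN;      // row=0, col=cCol
            float* c = C + cRow * BM * N + cCol * BN;
            std::fill(threadResults.begin(), threadResults.end(), 0.0f);
            for (int bkIdx = 0; bkIdx < K; bkIdx += BK) {
                // Load phase: each thread copies one element of A and one of B.
                for (int tid = 0; tid < kThreads; ++tid) {
                    const int innerRowA = tid / BK, innerColA = tid % BK;
                    const int innerRowB = tid / BN, innerColB = tid % BN;
                    As[innerRowA * BK + innerColA] = a[innerRowA * K + innerColA];
                    Bs[innerRowB * BN + innerColB] = b[innerRowB * N + innerColB];
                }
                // (first __syncthreads(): all loads finish before any compute)
                for (int tid = 0; tid < kThreads; ++tid) {
                    const int threadRow = tid / BN, threadCol = tid % BN;
                    float* res = &threadResults[tid * TM];
                    for (int dotIdx = 0; dotIdx < BK; ++dotIdx) {
                        const float Btmp = Bs[dotIdx * BN + threadCol];
                        for (int resIdx = 0; resIdx < TM; ++resIdx)
                            res[resIdx] += As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
                    }
                }
                // (second __syncthreads(): all compute finishes before the next load)
                a += BK;      // slide the A window right
                b += BK * N;  // slide the B window down
            }
            // Each thread writes its column of TM results back to C.
            for (int tid = 0; tid < kThreads; ++tid) {
                const int threadRow = tid / BN, threadCol = tid % BN;
                for (int resIdx = 0; resIdx < TM; ++resIdx)
                    c[(threadRow * TM + resIdx) * N + threadCol] = threadResults[tid * TM + resIdx];
            }
        }
    }
}

// Naive reference: one dot product per output entry.
void sgemmReference(int M, int N, int K, const float* A, const float* B, float* C) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
}
```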

### Performance Talk
So now onto the gains - 2.2x faster than the previous kernel!

Let's try to see why, based on the memory accesses performed by each thread. Again, remember why the previous kernel left some performance to be desired: the FMA instructions were stalling, waiting for memory accesses (i.e. loads) to finish.

For the previous kernel,
- For global memory, we have K/BLOCKSIZE iterations in the outer loop, since we need to slide across the K dimension (columns of A and rows of B), in increments of the chunks we are grabbing from them, which are BLOCKSIZE big. And for each of these outer loops, each thread makes two global memory loads when accessing the elements from A and B, to put into As and Bs (loading from A and B is the GMEM access).
- For shared memory, we again have K/BLOCKSIZE iterations of the outer loop, multiplied by BLOCKSIZE iterations of the inner loop, since we iterate through the whole BLOCKSIZE-long row and column to get the partial dot product. In this inner loop we also have two SMEM accesses, to compute the product of each As and Bs element pair.

Thus, we have:
- $K/\text{BLOCKSIZE} \cdot 2 = K/32 \cdot 2 = K/16$ GMEM accesses per result
- $K/\text{BLOCKSIZE} \cdot \text{BLOCKSIZE} \cdot 2 = K/32 \cdot 32 \cdot 2 = 2K$ SMEM accesses per result

Now, for the current kernel,
- For global memory, we now have K/BK iterations of the outer loop, since we no longer have square blocks with a single BLOCKSIZE, but we still slide along the columns of A and rows of B. In each iteration, a thread still makes just two global memory loads, one element of A and one element of B, which it writes into As and Bs.
- For shared memory, we have the K/BK iterations of the outer loop multiplied by the BK iterations of the first inner loop, since that's the number of partial-partial dot products we have to find, i.e. the column we use from Bs has BK elements and we need to loop through all of them. In each of these iterations we have one shared memory access to grab the Btmp element from Bs, and TM shared memory accesses to grab one element of As for each of the TM rows our thread is responsible for.

Thus, we have:
- $K/\text{BK} \cdot 2 = K/8 \cdot 2 = K/4$ GMEM accesses per thread. Since each thread computes 8 results, that is $K/32$ GMEM accesses per result
- $K/\text{BK} \cdot \text{BK} \cdot (1 + \text{TM}) = K/8 \cdot 8 \cdot (1 + 8) = 9K$ SMEM accesses per thread. Since each thread computes 8 results, that is $9K/8$ SMEM accesses per result

So, clearly, far fewer trips to both global and shared memory are made for each entry of the result matrix.
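The counting above can be written down as a tiny C++ sketch (the function names are hypothetical; BLOCKSIZE = 32 for the previous kernel and BK = TM = 8 for this one, as in the text). It just turns the per-thread counts into per-result counts.

```cpp
#include <cassert>

// Accesses per output entry, following the per-thread counts in the text.
struct AccessesPerResult {
    double gmem;
    double smem;
};

// Previous kernel: K/BLOCKSIZE outer iterations x 2 GMEM loads; BLOCKSIZE inner
// iterations x 2 SMEM loads; one result per thread.
AccessesPerResult previousKernel(double K, double BLOCKSIZE = 32) {
    assert(K > 0 && BLOCKSIZE > 0);
    const double outer = K / BLOCKSIZE;
    return {outer * 2, outer * BLOCKSIZE * 2};
}

// 1D blocktiling: K/BK outer iterations x 2 GMEM loads; BK inner iterations
// x (1 + TM) SMEM loads; TM results per thread.
AccessesPerResult blocktiling1D(double K, double BK = 8, double TM = 8) {
    assert(K > 0 && BK > 0 && TM > 0);
    const double outer = K / BK;
    return {outer * 2 / TM, outer * BK * (1 + TM) / TM};
}
```

For K = 4096 this gives 256 GMEM and 8192 SMEM accesses per result for the previous kernel, versus 128 and 4608 for this one.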

This can be seen further by revisiting the plot we saw for the previous kernel:

![](../../images/GEMM1/warpstates2.png)

This shows that we now spend far fewer cycles per instruction stalling.

### Compiler Optimizations, More PTX
In our inner loops, we first cached the entry of Bs since it would be reused, and computed the partial-partial dot products for all TM elements together, instead of finding the partial dot product for each of the TM elements one by one, since we made fun of that naive way. Well, funnily enough, if we did it the naive way, like so:

```cpp
// naive order: finish each result's dot product before moving to the next
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
  for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
    threadResults[resIdx] +=
      As[(threadRow * TM + resIdx) * BK + dotIdx] * Bs[dotIdx * BN + threadCol];
  }
}
```

There is actually no effect on performance. Weird, right?

This is especially strange because, if we check the math behind the number of SMEM accesses per chunk, the naive version has BK * TM * 2 = 8 * 8 * 2 = 128 accesses, whereas before we had BK * (1 + TM) = 8 * (1 + 8) = 72 accesses.

At instances like this, where even Simon is scratching his head a little, he advises us to look towards the PTX for some clues.

Since the insight only becomes visible across loop iterations, and assembly code is quite verbose anyway, I won't share the whole listing, but I will explain what happens.

First, a quick aside on a concept called loop unrolling. This is a compiler transformation where loops with small iteration counts that are known at compile time (like BK and TM, which are compile-time constants here) are expanded into straight-line code, removing the branch/loop overhead. It may seem unrelated, but this is the key to the mystery.

In our (naive) case, instead of generating the loop structure, the compiler unrolls both loops into straight-line code with 64 FMAs: the first 8 lines are the 8 multiply-adds needed for the first of the TM elements, the next 8 lines are for the next element, and so on.

Since the compiler can see all 64 of these instructions at once, it notices that for the same dotIdx (the inner loop counter) the same element of Bs is loaded again and again, so it loads that element only once and keeps it in a register (just like Btmp) for reuse. As a result, each of the BK Bs elements is read once instead of TM = 8 times, i.e. 7 fewer accesses per element: once the first of the TM results has loaded the Bs values of the column, the other TM - 1 results reuse them.
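To picture what the compiler does, here is a hand-written C++ illustration, shrunk to TM = 2 and BK = 2 so the unrolled code fits on screen. It is a sketch under assumptions, not actual compiler output: `As` holds just the TM rows this thread owns (TM × BK, row-major), `Bs` just the BK values of its column, and the function names are hypothetical.

```cpp
#include <cassert>

// Shrunk sizes for illustration only; the real kernel uses TM = BK = 8.
constexpr int TM = 2, BK = 2;

// Naive nesting: resIdx outer, dotIdx inner, Bs read inside the inner loop.
// As holds this thread's TM rows (TM x BK, row-major); Bs holds its BK column values.
void naiveLoops(const float* As, const float* Bs, float* res) {
    assert(As != nullptr && Bs != nullptr && res != nullptr);
    for (int resIdx = 0; resIdx < TM; ++resIdx)
        for (int dotIdx = 0; dotIdx < BK; ++dotIdx)
            res[resIdx] += As[resIdx * BK + dotIdx] * Bs[dotIdx];
}

// What unrolling plus load reuse effectively turns that into: straight-line
// code where each Bs value is loaded once into a register and reused.
void unrolledWithReuse(const float* As, const float* Bs, float* res) {
    const float b0 = Bs[0], b1 = Bs[1];  // BK reads of Bs instead of TM * BK
    res[0] += As[0] * b0;  // result 0, dotIdx 0
    res[0] += As[1] * b1;  // result 0, dotIdx 1
    res[1] += As[2] * b0;  // result 1, dotIdx 0
    res[1] += As[3] * b1;  // result 1, dotIdx 1
}
```

In the real kernel (TM = BK = 8) the same pattern gives 64 FMAs that share 8 loaded Bs values.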

This means that we have BK(TM-1) = 8(8-1) = 8(7) = 56 fewer SMEM accesses. And since the naive method seemed to need 128 accesses, we get 128 - 56 = 72, the same number of SMEM accesses as our optimized version.
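The same bookkeeping as a small C++ sketch that counts SMEM reads per thread per chunk for the three variants (hypothetical function name; BK = TM = 8 as in the text). The third variant models the compiler's reuse by reading each Bs element only the first time it is needed.

```cpp
#include <cassert>
#include <vector>

// Shared-memory reads per thread per chunk for the three variants in the text.
struct SmemReads {
    int naive;            // resIdx outer, dotIdx inner, no reuse
    int cachedBtmp;       // the kernel's order with Btmp
    int naiveAfterReuse;  // naive order once each Bs value is kept in a register
};

SmemReads countSmemReads(int BK = 8, int TM = 8) {
    assert(BK > 0 && TM > 0);
    SmemReads r{0, 0, 0};
    std::vector<bool> bsInRegister(BK, false);
    for (int resIdx = 0; resIdx < TM; ++resIdx)
        for (int dotIdx = 0; dotIdx < BK; ++dotIdx) {
            r.naive += 2;             // one As read + one Bs read per FMA
            r.naiveAfterReuse += 1;   // the As read is still needed
            if (!bsInRegister[dotIdx]) {  // Bs is read only the first time
                bsInRegister[dotIdx] = true;
                r.naiveAfterReuse += 1;
            }
        }
    for (int dotIdx = 0; dotIdx < BK; ++dotIdx)
        r.cachedBtmp += 1 + TM;       // one Btmp read + TM As reads
    return r;
}
```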

### Arithmetic Intensity

Closing this kernel off, let's first talk about another new concept: arithmetic intensity.

It is defined as _the number of FLOPs executed per byte of data transferred (load + store) between GMEM and SMEM_.

I'll share a really good diagram by Simon that kinda just speaks for itself, showing how calculating more results per thread raises arithmetic intensity:

![](../../images/GEMM1/arithmeticintensity.png)

Also, maybe something we should have looked at before this kernel, haha, but:

![](../../images/GEMM1/arithmeticintensity2.png)

So, essentially, all our kernels perform the same number of FLOPs, and they have to, since they compute the same matrix multiplication. We just use the naive $O(N^3)$ algorithm and don't attempt a theoretically faster algorithm (e.g. Strassen), since those come with constraints and often increase memory traffic. The point is that modern GPUs are so fast at computation that, for kernels like ours, the bottleneck is usually memory bandwidth rather than the number of FLOPs.

Since the number of FLOPs stays the same, the only thing we can optimize is memory traffic, so we essentially have to increase the number of FLOPs per byte loaded, which happens when we perform fewer loads.
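To put numbers on this, here is a back-of-the-envelope C++ sketch of arithmetic intensity for one block and one outer-loop iteration. It assumes 4-byte floats, counts an FMA as 2 FLOPs, and counts only the GMEM-to-SMEM tile loads (the final store of C happens once per output and is ignored); `tileArithmeticIntensity` is a hypothetical helper. The previous kernel's 32 × 32 × 32 step and this kernel's 64 × 64 × 8 step (assuming BM = BN = 64) both perform 65,536 FLOPs, but the new tiles move half as many bytes.

```cpp
#include <cassert>

// FLOPs per byte moved from GMEM to SMEM for one block and one outer-loop step.
// Assumes 4-byte floats, 2 FLOPs per FMA, and ignores the one-off store of C.
// tileM x tileN is the block's output tile; tileK is the chunk width along K.
double tileArithmeticIntensity(double tileM, double tileN, double tileK) {
    assert(tileM > 0 && tileN > 0 && tileK > 0);
    const double flops = 2.0 * tileM * tileN * tileK;            // multiply + add
    const double bytes = 4.0 * (tileM * tileK + tileK * tileN);  // As + Bs tiles
    return flops / bytes;
}
```

This gives 8 FLOPs/byte for the previous kernel and 16 FLOPs/byte for this one, the same factor of two as the drop from K/16 to K/32 GMEM accesses per result.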

![](../../images/GEMM1/arithmeticintensityeq.png)



There we have it then: we have to keep optimizing arithmetic intensity. Onwards to the next kernel!