## What is SplitK?
In a traditional GEMM, each threadblock computes one output tile and performs the reduction over the **entire K dimension**. When K is large, this increases the work per threadblock.

Instead of having one threadblock compute a tile's output by reducing over all of K, **SplitK** assigns the same output tile to multiple threadblocks $TB_0$, $TB_1$, ..., $TB_i$, each of which reduces over its own slice of K. The partial results are then combined using atomics (atomic add).

In other words, each threadblock computes a partial result for its tile rather than the complete tile, and the partials are summed into the output with **atomic_add**. As a by-product, each threadblock has less work to do than it would in a standard data-parallel tiling, and the larger number of threadblocks leads to better SM utilization.
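To make the idea concrete, here is a minimal pure-Python sketch (no GPU; `matmul_ref` and `matmul_splitk` are hypothetical helpers). Each value of `pid_k` plays the role of one threadblock reducing over its own contiguous slice of K, and the `c[m][n] += partial` line stands in for the atomic add. It illustrates the arithmetic only; on the GPU the `pid_k` programs run concurrently.

```python
# Pure-Python emulation of SplitK (hypothetical helpers, no GPU involved).


def matmul_ref(a, b):
    """Standard GEMM: each output element is one reduction over the entire K dimension."""
    M, K, N = len(a), len(b), len(b[0])
    return [[sum(a[m][k] * b[k][n] for k in range(K)) for n in range(N)] for m in range(M)]


def matmul_splitk(a, b, split_k):
    """SplitK GEMM: split_k partial reductions over contiguous K slices, summed into C."""
    M, K, N = len(a), len(b), len(b[0])
    c = [[0] * N for _ in range(M)]  # the atomic-add target must start at zero
    chunk = -(-K // split_k)         # ceil(K / split_k) elements of K per "threadblock"
    for pid_k in range(split_k):     # on the GPU these programs run concurrently
        k_lo, k_hi = pid_k * chunk, min((pid_k + 1) * chunk, K)
        for m in range(M):
            for n in range(N):
                partial = sum(a[m][k] * b[k][n] for k in range(k_lo, k_hi))
                c[m][n] += partial   # stands in for tl.atomic_add
    return c


act = [[1, 2, 3, 4, 5, 6]]                              # M = 1, K = 6
wgt = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 2], [1, 1]]  # K = 6, N = 2
print(matmul_ref(act, wgt))        # [[18, 21]]
print(matmul_splitk(act, wgt, 3))  # [[18, 21]]
```

Any split factor gives the same integer result; only the partition of K changes.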

## When may SplitK be effective?
1. Memory-bound GEMMs: inference workloads often produce skinny GEMMs.
2. Due to splitting along the K dimension, each threadblock ....
3. Reducing wave quantization ...

4. Better load balancing for compute-bound kernels; also useful for memory-bound kernels.
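The wave-quantization and load-balancing points can be illustrated with a back-of-the-envelope wave count. This pure-Python sketch assumes an A100 with 108 SMs, only one resident threadblock per SM (real kernels often fit several), and the 128 x 128 tiles used in the kernel code, so M = 128, N = 4096 yields just 32 output tiles; `waves` is a hypothetical helper.

```python
import math


def waves(num_blocks, num_sms, blocks_per_sm=1):
    """Return (number of waves, fraction of the last wave that is occupied)."""
    slots = num_sms * blocks_per_sm  # threadblocks that can be resident at the same time
    n_waves = math.ceil(num_blocks / slots)
    last_wave_fill = (num_blocks - (n_waves - 1) * slots) / slots
    return n_waves, last_wave_fill


NUM_SMS = 108  # A100 SM count
TILES = 32     # M = 128, N = 4096 with 128 x 128 tiles: 1 x 32 output tiles

for split_k in (1, 3, 4):
    n, fill = waves(TILES * split_k, NUM_SMS)
    print(f"SPLIT_K={split_k}: {TILES * split_k:3d} blocks, {n} wave(s), last wave {fill:.0%} full")
```

Under these assumptions, 32 tiles leave most SMs idle, SPLIT_K = 3 fills most of a single wave, and SPLIT_K = 4 spills 20 blocks into a second, mostly empty wave. The best split depends on SM count and occupancy, which is why the split factor is worth tuning per problem shape.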

## Trade-offs
* Note that there is a tension between the improvements from finer-grained SM work distribution and the overhead of thread blocks contending for exclusive write access to the same output buffer. This effect was seen on an A100, where increasing the SplitK parameter from 4 to 16 resulted in a steady degradation of performance as the matrix sizes increased, presumably due to longer wait times per thread block to get exclusive write access to the same output buffer.
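A first-order way to see the tension: every output element receives one atomic update per K split, so atomic traffic on the output grows linearly with SPLIT_K while the work per threadblock shrinks. This sketch (hypothetical `atomic_updates` helper, assuming fp32 atomics and the M = 128, N = 4096 output from the kernel design) only counts updates; it does not model contention latency itself.

```python
def atomic_updates(M, N, split_k, bytes_per_update=4):
    """Count atomic adds landing on the [M, N] output: one per element per K split."""
    updates = M * N * split_k
    return updates, updates * bytes_per_update


for split_k in (4, 8, 16):
    updates, nbytes = atomic_updates(128, 4096, split_k)
    print(f"SPLIT_K={split_k:>2}: {updates:,} atomic adds, "
          f"{nbytes / 2**20:.0f} MiB of atomic traffic, "
          f"{split_k} threadblocks writing each output tile")
```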

## Real world
1. In online inference workloads, batch sizes are often small, yielding a skinny activation matrix.

$C = \alpha \, A W + \beta \, C$


Intuitively: when multiplying two matrices (i.e. performing a GEMM), each element of the output is computed as a sum of products along the K dimension. If that sum is very long, it can become a bottleneck. SplitK divides this long sum into several shorter, independent sums that can be computed in parallel and then added together.
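The same idea for a single output element, as a pure-Python sketch (`dot_splitk` is a hypothetical helper): the K-long sum is cut into `split_k` near-equal slices, each slice is reduced independently, and the partial sums are added at the end.

```python
def dot(x, y):
    """One long reduction over K."""
    return sum(a * b for a, b in zip(x, y))


def dot_splitk(x, y, split_k):
    """Split the K-long sum into split_k shorter, independent sums, then combine them."""
    K = len(x)
    bounds = [K * i // split_k for i in range(split_k + 1)]  # near-equal slice boundaries
    partials = [dot(x[lo:hi], y[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
    return sum(partials), partials


x = list(range(1, 9))  # K = 8
y = [1] * 8
total, partials = dot_splitk(x, y, 4)
print(partials)          # [3, 7, 11, 15]
print(total, dot(x, y))  # 36 36
```

With integers the result is exact. With floating-point inputs, splitting changes the order of additions, and because floating-point addition is not associative a SplitK result can differ from the unsplit reduction in the last few bits; atomics also make the order of the final additions nondeterministic.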

## Readings
1. https://github.com/pytorch-labs/applied-ai/tree/main/kernels
2. https://arxiv.org/pdf/2402.00025
3. https://pytorch.org/blog/accelerating-triton/
4. https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/
5. https://pytorch.org/blog/accelerating-moe-model/#30-work-decomposition---splitk
6. https://github.com/pytorch-labs/gpt-fast/tree/main

 


## Kernel Design

Assume a batch size of 128:

* Input (A) = [128, 4096] (notice how this GEMM is skinny: 128 << 4096)
* Weight (W) = [4096, 4096]
* Output (C) = [128, 4096] 

So given these activation and weight matrices, we have **M = 128**, **K = 4096**, and **N = 4096**.

* BLOCK_M = 16
* BLOCK_N = 16

Therefore the [128, 4096] output is divided into (128/16, 4096/16) = (8, 256) tiles of 16 x 16, each computed by one threadblock. **Grid size = 8 * 256 = 2048**
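The grid arithmetic as a quick pure-Python check. `cdiv` mirrors the ceiling division of `triton.cdiv`; the split factors in the loop are arbitrary examples, and the per-block K count assumes K divides evenly.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


M, K, N = 128, 4096, 4096
block_m, block_n = 16, 16

grid_m, grid_n = cdiv(M, block_m), cdiv(N, block_n)
print((grid_m, grid_n), grid_m * grid_n)  # (8, 256) 2048

# With SplitK the grid gains a third dimension of size SPLIT_K,
# and each threadblock reduces over only K / SPLIT_K elements.
for split_k in (1, 4, 16):
    print(f"SPLIT_K={split_k:>2}: {grid_m * grid_n * split_k} threadblocks, "
          f"{K // split_k} K elements each")
```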




```py
import triton
import triton.language as tl
import torch
import triton.testing as tt
```


```py
# Output tile sizes used by the kernel below.
BLOCK_M = 128
BLOCK_N = 128
```


```py
import torch
import triton
import triton.language as tl

# BLOCK_M and BLOCK_N are defined in the previous cell (128 x 128 output tiles).


@triton.jit
def splitK_kernel(
    a_ptr,    # pointer to GPU location of matrix A, shape [M, K]
    b_ptr,    # pointer to GPU location of matrix B, shape [K, N]
    out_ptr,  # pointer to GPU location of the output, shape [M, N]; must be zero-initialised
    M, N, K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_om,
    stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    # A program ID identifies one instance of the program; each instance is executed by a thread block.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)

    # Starting indices for each block along m and n.
    # Assume (pid_m, pid_n) = (2, 3); then off_am = 2 * 128 + [0, 1, 2, ..., 127]
    # off_am = [256, 257, 258, ..., 383]
    # off_bn = [384, 385, 386, ..., 511]
    # Therefore, this TB contributes to the 128x128 output tile indexed by off_am and off_bn.
    off_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # ROW indices for A and C
    off_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # COL indices for B and C

    # fp32 accumulator for the partial result of this thread block.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # K is cut into BLOCK_K-wide slices; program pid_k reduces slices
    # pid_k, pid_k + SPLIT_K, pid_k + 2 * SPLIT_K, ...
    num_slices_K = tl.cdiv(K, BLOCK_K)
    for k in range(pid_k, num_slices_K, SPLIT_K):
        off_k = k * BLOCK_K + tl.arange(0, BLOCK_K)

        # Load a [BLOCK_M, BLOCK_K] tile of A; the mask prevents reading past the matrix edges.
        a_tile_ptrs = a_ptr + off_am[:, None] * stride_am + off_k[None, :] * stride_ak
        a = tl.load(a_tile_ptrs, mask=(off_am[:, None] < M) & (off_k[None, :] < K), other=0.0)

        # Load a [BLOCK_K, BLOCK_N] tile of B.
        b_tile_ptrs = b_ptr + off_k[:, None] * stride_bk + off_bn[None, :] * stride_bn
        b = tl.load(b_tile_ptrs, mask=(off_k[:, None] < K) & (off_bn[None, :] < N), other=0.0)

        acc += tl.dot(a, b)

    # SPLIT_K thread blocks add into the same output tile, so the store must be atomic.
    out_ptrs = out_ptr + off_am[:, None] * stride_om + off_bn[None, :] * stride_on
    out_mask = (off_am[:, None] < M) & (off_bn[None, :] < N)
    tl.atomic_add(out_ptrs, acc, mask=out_mask)


def splitk_gemm(
        A: torch.Tensor,
        B: torch.Tensor,
        splitK: int,
        BLOCK_K: int = 64,  # assumed K-slice width; worth tuning
) -> torch.Tensor:
    # ------ asserts -------------
    assert A.dim() == B.dim() == 2, "Inputs must be 2-dim matrices"
    assert A.shape[1] == B.shape[0], "Incompatible matrices for GEMM"

    # .contiguous() returns a (possibly new) tensor; it does not modify in place.
    A = A.contiguous()
    B = B.contiguous()

    M, K = A.shape
    N = B.shape[1]

    # Every split atomically adds into `out`, so it must start at zero.
    # Accumulate in fp32 and cast back to the input dtype at the end.
    out = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    # ----- Grid: one program per (output tile, K split) --------------
    x_grid = triton.cdiv(M, BLOCK_M)
    y_grid = triton.cdiv(N, BLOCK_N)
    z_grid = splitK
    grid = (x_grid, y_grid, z_grid)

    # -------- kernel launch --------
    splitK_kernel[grid](
        A,
        B,
        out,
        M, N, K,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        SPLIT_K=splitK,
    )

    return out.to(A.dtype)


def main():
    M = 128
    K = 4096
    N = 8192
    splitK = 32

    # Activation Matrix = [M, K] = [128, 4096]: a skinny GEMM since M << K, N
    # Weight Matrix = [K, N] = [4096, 8192]

    A = torch.randn(size=(M, K), dtype=torch.float16, device='cuda')
    W = torch.randn(size=(K, N), dtype=torch.float16, device='cuda')

    C_ref = torch.matmul(A, W)
    C_splitK = splitk_gemm(A, W, splitK)
    print("max abs difference vs torch.matmul:", (C_ref.float() - C_splitK.float()).abs().max().item())


if __name__ == "__main__":
    main()
```
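A quick host-side check of the strided K partition: program `pid_k` visits K slices `pid_k, pid_k + SPLIT_K, pid_k + 2 * SPLIT_K, ...`, so across all programs every slice should be reduced exactly once. The sketch is pure Python and assumes BLOCK_K-wide slices with an assumed BLOCK_K = 64; `k_slices_for` and `check_partition` are hypothetical helpers, not part of Triton.

```python
def k_slices_for(pid_k, K, block_k, split_k):
    """K-slice indices visited by program pid_k in a strided SplitK loop."""
    num_k_slices = (K + block_k - 1) // block_k  # ceil(K / block_k)
    return list(range(pid_k, num_k_slices, split_k))


def check_partition(K, block_k, split_k):
    """True if every K slice is assigned to exactly one pid_k."""
    visited = [s for pid_k in range(split_k) for s in k_slices_for(pid_k, K, block_k, split_k)]
    num_k_slices = (K + block_k - 1) // block_k
    return sorted(visited) == list(range(num_k_slices))


print(k_slices_for(0, 4096, 64, 32))    # [0, 32]
print(k_slices_for(31, 4096, 64, 32))   # [31, 63]
print(check_partition(4096, 64, 32))    # True
# If SPLIT_K exceeds the number of slices, some programs get no work at all:
print(k_slices_for(70, 4096, 64, 128))  # []
```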
    

## Tuning
Relevant hyperparameters for our kernel include the tile sizes, the number of warps, and the number of pipeline stages.
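One way to organise the search is to enumerate candidate values and discard combinations that obviously cannot fit. This pure-Python sketch is an assumption throughout: the candidate values, the 96 KiB shared-memory budget, and the estimate of one fp16 A tile plus one fp16 B tile per pipeline stage (the compiler's real usage can differ). The surviving dictionaries could then be turned into `triton.Config` objects for `@triton.autotune`.

```python
import itertools


def smem_estimate(block_m, block_n, block_k, num_stages, elem_bytes=2):
    """Rough shared-memory bytes: an A tile and a B tile buffered per pipeline stage."""
    return (block_m * block_k + block_k * block_n) * elem_bytes * num_stages


def candidate_configs(smem_budget=96 * 1024):
    """Yield tile/warp/stage/split combinations whose estimated shared memory fits the budget."""
    space = itertools.product(
        (32, 64, 128),  # BLOCK_M
        (32, 64, 128),  # BLOCK_N
        (32, 64),       # BLOCK_K
        (4, 8),         # num_warps
        (2, 3, 4),      # num_stages
        (1, 4, 16),     # SPLIT_K
    )
    for bm, bn, bk, warps, stages, split_k in space:
        if smem_estimate(bm, bn, bk, stages) <= smem_budget:
            yield {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "SPLIT_K": split_k,
                   "num_warps": warps, "num_stages": stages}


configs = list(candidate_configs())
print(len(configs), "candidate configs")
print(configs[0])
```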

## Kernel Metrics 
| Metric | SplitK | Data Parallel |
|---|---|---|
| Latency (µs) | 27.90 | 52.93 |
| Global memory throughput (GB/s), higher is better | 313 | 161 |
| Grid size (thread block launch count) | 512 | 128 |
| Registers | 92 | 150 |
| Shared memory usage (KB) | 102.40 | 167.94 |
| Block limit (registers) | 5 | 3 |
| Block limit (SMEM) | 5 | 2 |
| Achieved occupancy | 27.75 | 7.55 |
| SM utilization | 43.05% | 20.75% |
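The headline ratios implied by the table, computed directly from its numbers (no new measurements):

```python
splitk = {"latency_us": 27.90, "gmem_gbps": 313, "sm_util_pct": 43.05}
data_parallel = {"latency_us": 52.93, "gmem_gbps": 161, "sm_util_pct": 20.75}

speedup = data_parallel["latency_us"] / splitk["latency_us"]
throughput_gain = splitk["gmem_gbps"] / data_parallel["gmem_gbps"]
util_gain = splitk["sm_util_pct"] / data_parallel["sm_util_pct"]

print(f"latency speedup:        {speedup:.2f}x")          # 1.90x
print(f"memory throughput gain: {throughput_gain:.2f}x")  # 1.94x
print(f"SM utilization gain:    {util_gain:.2f}x")        # 2.07x
```

The roughly 1.9x lower latency tracks the roughly 1.9x higher global-memory throughput, which is consistent with this being a memory-bound skinny GEMM.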

## Notes on Internals 
 
```py
offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # row indices of the A tile
offs_k = tl.arange(0, BLOCK_K)                     # column indices of the first K slice
a_ptr = A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
```

and 

```py
# make_block_ptr uses the provided block shape to compute the full set of row/column
# offsets internally, so only the tile's starting offset is needed.
offs_am = pid_m * BLOCK_M
a_block_ptr = tl.make_block_ptr(
    base=A,
    shape=(M, K),
    strides=(stride_am, stride_ak),
    offsets=(offs_am, 0),
    block_shape=(BLOCK_M, BLOCK_K),
    order=(1, 0)
)
```

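are equivalent: both address the same [BLOCK_M, BLOCK_K] tile of A, starting at row `pid_m * BLOCK_M` and column 0.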
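A pure-Python model of what the two snippets compute, showing that they address the same elements. It models the addressing only, not how Triton lowers block pointers; both helper names are hypothetical.

```python
def manual_tile_offsets(pid_m, block_m, block_k, stride_am, stride_ak, k_start=0):
    """Offsets built by hand: offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak."""
    offs_am = [pid_m * block_m + i for i in range(block_m)]
    offs_k = [k_start + j for j in range(block_k)]
    return [[m * stride_am + k * stride_ak for k in offs_k] for m in offs_am]


def block_ptr_tile_offsets(offsets, block_shape, strides):
    """Offsets a block pointer describes: a block_shape tile starting at `offsets` in the parent."""
    (row0, col0), (rows, cols), (s_row, s_col) = offsets, block_shape, strides
    return [[(row0 + i) * s_row + (col0 + j) * s_col for j in range(cols)] for i in range(rows)]


K = 32
stride_am, stride_ak = K, 1  # row-major A of shape [M, K]
block_m, block_k, pid_m = 16, 8, 2

manual = manual_tile_offsets(pid_m, block_m, block_k, stride_am, stride_ak)
blocked = block_ptr_tile_offsets((pid_m * block_m, 0), (block_m, block_k), (stride_am, stride_ak))
print(manual == blocked)  # True
print(manual[0][:3])      # [1024, 1025, 1026]
```

One practical difference: with raw pointers, out-of-bounds elements are handled through `mask=` on `tl.load`, whereas with block pointers you pass `boundary_check=` to `tl.load` and step through K with `tl.advance(a_block_ptr, (0, BLOCK_K))`.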



```py
a_block_ptr = tl.make_block_ptr(
    base=A,                                      # base pointer to the parent tensor
    shape=(M, K),                                # shape of the parent tensor
    strides=(stride_am, stride_ak),              # strides of the parent tensor
    offsets=(pid_m * BLOCK_M, pid_k * BLOCK_K),  # offsets to the start of the block, i.e. the tile
    block_shape=(BLOCK_M, BLOCK_K),              # shape/size of the block
    order=(1, 0)                                 # row-major: dim 1 (K) is the contiguous one
)
```
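The `order` argument lists the tensor's dimensions from fastest- to slowest-varying in memory, which is why row-major tiles in the Triton tutorials use `order=(1, 0)`. A small pure-Python helper (hypothetical, for illustration) derives it from the strides:

```python
def order_from_strides(strides):
    """Dimensions sorted from fastest- to slowest-varying (smallest stride first)."""
    return tuple(sorted(range(len(strides)), key=lambda d: strides[d]))


M, K = 128, 4096
print(order_from_strides((K, 1)))  # row-major [M, K]    -> (1, 0)
print(order_from_strides((1, M)))  # column-major [M, K] -> (0, 1)
```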