In [1]:
import torch
from functools import partial

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

# Tutorial: Elementwise Add Kernel in CuTe DSL

This tutorial demonstrates how to implement a simple elementwise
addition kernel using the CuTe DSL (Domain Specific Language).



Elementwise Addition
---------------------

Elementwise addition is a fundamental operation in linear algebra.
Given two tensors of the same shape, the operation performs element-wise
addition to produce a result tensor of the same shape.

For two 2D tensors $A$ and $B$ of shape $(M, N)$,
the elementwise addition operation $C = A + B$ is defined as:

$$
C_{i,j} = A_{i,j} + B_{i,j}
$$

where:

- $i \in [0, M-1]$ represents the row index
- $j \in [0, N-1]$ represents the column index
- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ 
  in tensors $A$, $B$, and $C$ respectively

This operation is performed independently for each element position,
making it highly parallelizable and well-suited for GPU implementation.
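For reference, the definition above can be written as a plain-Python loop over nested lists. This is only a CPU sketch of the math, useful for reasoning about indices before moving to the GPU. The helper name `elementwise_add_ref` is hypothetical. Because every `(i, j)` iteration is independent, the loops could run in any order or in parallel.

```python
def elementwise_add_ref(A, B):
    """Return C with C[i][j] = A[i][j] + B[i][j] for two equally shaped 2D lists."""
    M, N = len(A), len(A[0])
    assert len(B) == M and all(len(row) == N for row in B), "shapes must match"
    C = [[0] * N for _ in range(M)]
    for i in range(M):          # row index, 0 <= i <= M-1
        for j in range(N):      # column index, 0 <= j <= N-1
            C[i][j] = A[i][j] + B[i][j]
    return C


A = [[1, 2, 3], [4, 5, 6]]
B = [[10, 20, 30], [40, 50, 60]]
print(elementwise_add_ref(A, B))  # [[11, 22, 33], [44, 55, 66]]
```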

Naive Elementwise Add Kernel
-----------------------------

Let's start with a naive implementation that loads each element from
$A$ and $B$, adds them, and stores the result back to $C$.

In [2]:
@cute.kernel
def naive_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx

    # Map thread index to logical index of input tensor
    m, n = gA.shape
    ni = thread_idx % n
    mi = thread_idx // n

    # Map logical index to physical address via tensor layout
    a_val = gA[mi, ni]
    b_val = gB[mi, ni]

    # Perform element-wise addition
    gC[mi, ni] = a_val + b_val

### Structure of the Kernel

The naive kernel maps each thread to exactly one element (a 1-to-1 mapping).
It does not use CuTe layout algebra; it indexes the tensor with basic
`(row, column)` coordinates only.
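To see the 1-to-1 mapping concretely, the kernel's index arithmetic can be replayed on the CPU. This sketch reproduces `thread_idx = bidx * bdim + tidx`, `ni = thread_idx % n`, and `mi = thread_idx // n`. The function name `thread_to_coord` is hypothetical. The sketch shows that neighboring threads touch neighboring columns of the same row, which is what makes the accesses coalesced for a row-major tensor.

```python
def thread_to_coord(bidx, tidx, bdim, n):
    """Replays the naive kernel's index math: global thread id -> (row, col)."""
    thread_idx = bidx * bdim + tidx
    mi = thread_idx // n   # row
    ni = thread_idx % n    # column (fastest-varying across consecutive threads)
    return mi, ni


M, N, BDIM = 2048, 2048, 256
# Threads 0..3 of block 0 cover the first four columns of row 0.
print([thread_to_coord(0, t, BDIM, N) for t in range(4)])  # [(0, 0), (0, 1), (0, 2), (0, 3)]
# Each row holds 2048 / 256 = 8 blocks' worth of threads, so block 8 starts row 1.
print(thread_to_coord(8, 0, BDIM, N))  # (1, 0)
```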

We can launch the kernel with the following JIT function:

In [3]:
@cute.jit
def naive_elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor
):
    num_threads_per_block = 256

    m, n = mA.shape
    kernel = naive_elementwise_add_kernel(mA, mB, mC)
    kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),
                  block=(num_threads_per_block, 1, 1))

M, N = 2048, 2048

a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

a_ = from_dlpack(a, assumed_align=16)
b_ = from_dlpack(b, assumed_align=16)
c_ = from_dlpack(c, assumed_align=16)

# Compile kernel
naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)
naive_elementwise_add_(a_, b_, c_)

# verify correctness
torch.testing.assert_close(c, a + b)
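The launch configuration above computes the grid with integer division, `(m * n) // num_threads_per_block`, which assumes the element count is a multiple of 256. For the 2048×2048 inputs used here that holds exactly. The sketch below shows the arithmetic alongside the ceiling-division form that other sizes would need. The helper `grid_size` is hypothetical. With ceiling division, the kernel would also need a bounds check, which the naive kernel does not have.

```python
def grid_size(num_elements, threads_per_block):
    """Blocks needed so that every element gets one thread (ceiling division)."""
    return (num_elements + threads_per_block - 1) // threads_per_block


M, N, THREADS = 2048, 2048, 256
print((M * N) // THREADS, grid_size(M * N, THREADS))  # 16384 16384: exact fit

# A size that does not divide evenly: floor division would leave 64 elements unprocessed.
print((1000 * 1000) // THREADS, grid_size(1000 * 1000, THREADS))  # 3906 3907
```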

### Benchmark performance

Here's a utility function to benchmark our kernel implementations:

In [4]:
def benchmark(callable, *, num_warmups, num_iterations):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()

    # Warm-up runs are excluded from timing
    for _ in range(num_warmups):
        callable()

    # Time num_iterations back-to-back launches on the current stream
    start_event.record(stream=torch.cuda.current_stream())
    for _ in range(num_iterations):
        callable()
    end_event.record(stream=torch.cuda.current_stream())
    torch.cuda.synchronize()

    elapsed_time = start_event.elapsed_time(end_event)  # milliseconds
    avg_time = elapsed_time / num_iterations

    # Bytes per call: 3 tensors (read A, read B, write C) x numel x 2 bytes (float16).
    # Note: the problem size is taken from the global tensor `a`.
    print(f"Average execution time: {avg_time:.4f} ms")
    print(f"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s")

In [5]:
benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)

Average execution time: 0.0385 ms
Throughput: 653.44 GB/s
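The throughput line in `benchmark` counts 3 tensors (read `A`, read `B`, write `C`) × `M*N` elements × 2 bytes per `float16` element, divided by the average time. The sketch below redoes that arithmetic for the measured 0.0385 ms and gets about 653.66 GB/s. The printed 653.44 GB/s differs only because the displayed time is rounded to four decimals. The helper `effective_bandwidth_gbs` is hypothetical.

```python
def effective_bandwidth_gbs(m, n, time_ms, bytes_per_elem=2, num_tensors=3):
    """Bytes moved by C = A + B (2 reads + 1 write) divided by kernel time, in GB/s."""
    total_bytes = num_tensors * m * n * bytes_per_elem
    return total_bytes / (time_ms / 1000) / 1e9


print(f"{3 * 2048 * 2048 * 2:,} bytes per call")                  # 25,165,824 bytes per call
print(f"{effective_bandwidth_gbs(2048, 2048, 0.0385):.2f} GB/s")  # 653.66 GB/s
```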


### Performance Analysis

While our naive implementation maps thread indices to contiguous tensor
dimensions for coalesced memory access, it doesn't have enough
in-flight load & store operations to hide memory latency.

According to Little's Law:

$$ L = \lambda \times W $$

Where:
- $L$ is the average number of items in a system
- $\lambda$ is the average arrival rate of items (bandwidth)
- $W$ is the average time an item spends in the system (latency)

For our elementwise addition kernel:

1. $L$: The number of load & store operations in-flight
2. $\lambda$ (Bandwidth): Data transfer rate between memory and compute units
3. $W$ (Latency): Round-trip delay of memory requests

Memory-bound operations like elementwise addition do almost no arithmetic per byte.
Their performance is therefore limited by how many load & store operations are in flight at once.
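Rearranged as $\lambda = L / W$, Little's Law says that for a fixed latency, sustained bandwidth grows only with the amount of data in flight. The sketch below works through that arithmetic. The bandwidth and latency figures are illustrative placeholders, not measurements of any particular GPU. The per-thread byte counts come from this tutorial: one 2-byte `float16` per load in the naive kernel, and four of them per load in the vectorized kernel that follows. Holding latency fixed, moving 4× as many bytes per load needs 4× fewer concurrently waiting threads.

```python
def bytes_in_flight_needed(bandwidth_bytes_per_s, latency_ns):
    """Little's Law L = lambda * W, with W in nanoseconds (integer arithmetic)."""
    return bandwidth_bytes_per_s * latency_ns // 10**9


def threads_needed(bytes_in_flight, bytes_per_thread):
    """Concurrent threads needed if each keeps `bytes_per_thread` outstanding (ceil)."""
    return -(-bytes_in_flight // bytes_per_thread)


# ILLUSTRATIVE placeholders, not measured values for any specific GPU.
BANDWIDTH_BYTES_PER_S = 2 * 10**12   # 2 TB/s
LATENCY_NS = 500                     # 500 ns round trip

need = bytes_in_flight_needed(BANDWIDTH_BYTES_PER_S, LATENCY_NS)
print(need)  # 1000000 bytes must be outstanding

# Each thread has its A and B loads outstanding before it can add.
naive_bytes = 2 * 2       # two 2-byte float16 loads
vectorized_bytes = 2 * 8  # two 8-byte loads of 4 x float16
print(threads_needed(need, naive_bytes), threads_needed(need, vectorized_bytes))  # 250000 62500
```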

## Vectorized Load and Store

To improve performance according to Little's Law, we need to keep more data in flight.
We can do this by increasing the number of bytes each thread moves per load & store
operation, that is, through vectorized memory access.

Ampere GPUs support up to 128-bit loads and stores. Our tensors are `float16` (16-bit), so a
single 128-bit access could cover 8 elements. This example uses 4 contiguous elements within
a row per vectorized operation (a 64-bit access).
CuTe tiling operations make this vectorization straightforward.
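The vector-width arithmetic can be checked directly. The sketch below computes how many `float16` elements fit in a 64-bit and a 128-bit access. It also computes the launch size when each thread handles one `(1,4)` tile of the 2048×2048 tensor with 256 threads per block, which matches the grid computed in `vectorized_elementwise_add` further down. The helper name `elems_per_access` is hypothetical.

```python
def elems_per_access(access_bits, elem_bits):
    """Elements covered by one vectorized load/store of `access_bits` bits."""
    assert access_bits % elem_bits == 0
    return access_bits // elem_bits


FP16_BITS = 16
print(elems_per_access(64, FP16_BITS), elems_per_access(128, FP16_BITS))  # 4 8

M, N = 2048, 2048
TILE = (1, 4)                 # per-thread tile, as passed to zipped_divide
THREADS_PER_BLOCK = 256
num_tiles = (M // TILE[0]) * (N // TILE[1])
print(num_tiles, num_tiles // THREADS_PER_BLOCK)  # 1048576 4096
```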

Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input
``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``
as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).
Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.

```python
mA : cute.Tensor                           # (2048,2048):(2048,1)
gA = cute.zipped_divide(mA, (1, 4))        # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))
```

$$
\begin{array}{ccccc}
& ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\
& \underbrace{\phantom{(1,4)}}_{\text{tiler}} & & \underbrace{\phantom{(2048,512)}}_{\text{threads}} & \\
& \scriptsize\text{per-thread} & & \scriptsize\text{number of tiles} &
\end{array}
$$
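A CuTe layout is a function from coordinates to offsets. For a fully specified coordinate, the offset is the sum of each coordinate times its stride. The sketch below evaluates the tiled layout `((1,4),(2048,512)):((0,1),(2048,4))` in plain Python; it is not the CuTe API. The helper `layout_offset` is hypothetical and only handles fully specified nested coordinates. The sketch checks that value `v` of tile `(mi, ni)` lands on the same element as row `mi`, column `4*ni + v` of the original row-major `(2048,2048):(2048,1)` tensor.

```python
def layout_offset(coord, stride):
    """Offset of a fully specified (possibly nested) coordinate: sum of coord * stride.

    Shapes are not needed here because every sub-coordinate is given explicitly.
    """
    if isinstance(coord, tuple):
        return sum(layout_offset(c, s) for c, s in zip(coord, stride))
    return coord * stride


TILED_STRIDE = ((0, 1), (2048, 4))   # stride of ((1,4),(2048,512))
ROW_MAJOR_STRIDE = (2048, 1)         # stride of (2048,2048)

for mi, ni, v in [(0, 0, 0), (3, 5, 2), (2047, 511, 3)]:
    tiled = layout_offset(((0, v), (mi, ni)), TILED_STRIDE)   # size-1 sub-mode coord is 0
    flat = layout_offset((mi, 4 * ni + v), ROW_MAJOR_STRIDE)
    print((mi, ni, v), tiled, flat)
```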

In [6]:
@cute.kernel
def vectorized_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx

    # Map thread index to logical index of input tensor
    m, n = gA.shape[1]       # thread-domain
    ni = thread_idx % n
    mi = thread_idx // n

    # Map logical index to physical address via tensor layout
    a_val = gA[(None, (mi, ni))].load()
    b_val = gB[(None, (mi, ni))].load()
    print(f"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}")
    print(f"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}")

    # Perform element-wise addition
    gC[(None, (mi, ni))] = a_val + b_val

This vectorized kernel follows the same structure as its naive counterpart, with one key difference: the tensor slicing pattern.
Using `(None, (mi, ni))` as the slice index keeps the entire first (per-thread) mode and fixes the second mode at tile `(mi, ni)`.
This extracts a `(1,4)` sub-tensor from `gA`, `gB` and `gC`:

```python
gA[(None, (mi, ni))]
```

The sub-tensor's data can then be loaded into a vector with the `.load()` method.


```
                                         slice
    ((1,4),(2048,512)):((0,1),(2048,4))   ==>  ((1,4)):((0,1))
       ^     ^    ^
       |     |    |
     (None, (mi,  ni))
```
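Slicing with `(None, (mi, ni))` fixes the tile mode, so the slice is the `(1,4):(0,1)` layout shifted by a base offset of `mi*2048 + ni*4` elements. The sketch below computes that base offset in bytes (2 bytes per `float16`, relative to the 16-byte-aligned base from `assumed_align=16`). This suggests why the printed slices report `align<8>` rather than the `align<16>` of the full tensor: the offsets are always multiples of 8 bytes, but not always of 16. The helper `slice_base_bytes` is hypothetical.

```python
def slice_base_bytes(mi, ni, row_stride=2048, tile_stride=4, elem_bytes=2):
    """Byte offset of tile (mi, ni) in the zipped layout ((1,4),(2048,512)):((0,1),(2048,4))."""
    return (mi * row_stride + ni * tile_stride) * elem_bytes


offsets = [slice_base_bytes(mi, ni) for mi in range(4) for ni in range(512)]
print(all(o % 8 == 0 for o in offsets))    # True: every slice starts on an 8-byte boundary
print(all(o % 16 == 0 for o in offsets))   # False: e.g. tile (0, 1) starts at byte 8
print(slice_base_bytes(0, 1))              # 8
```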

In [7]:
@cute.jit
def vectorized_elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor
):
    threads_per_block = 256

    gA = cute.zipped_divide(mA, (1, 4))
    gB = cute.zipped_divide(mB, (1, 4))
    gC = cute.zipped_divide(mC, (1, 4))

    print(f"[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO]   gA = {gA}")
    print(f"[DSL INFO]   gB = {gB}")
    print(f"[DSL INFO]   gC = {gC}")

    vectorized_elementwise_add_kernel(gA, gB, gC).launch(
        grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),
        block=(threads_per_block, 1, 1),
    )

a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

a_ = from_dlpack(a, assumed_align=16)
b_ = from_dlpack(b, assumed_align=16)
c_ = from_dlpack(c, assumed_align=16)

compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)
compiled_func(a_, b_, c_)

# verify correctness
torch.testing.assert_close(c, a + b)

[DSL INFO] Tiled Tensors:
[DSL INFO]   gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>
[DSL INFO]   gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>
[DSL INFO]   gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>
[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>
[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>


In [8]:
benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)

Average execution time: 0.0202 ms
Throughput: 1244.98 GB/s


## TV Layout

Both the naive and vectorized kernels follow a common pattern to map thread indices
to physical addresses:

Step 1: Map thread index to logical M/N coordinates

```python
mi = thread_idx // n
ni = thread_idx % n
```

Step 2: Map logical M/N coordinates to physical addresses using the tensor layout

```python
gA[(None, (mi, ni))].load()
```

CuTe represents this mapping with a TV (thread-value) layout. A TV layout maps a thread index
and a value index (e.g., which of the 4 elements a thread loads) to the logical coordinate
space of a tensor. By configuring different TV layouts, we can experiment with different
memory access patterns with minimal code changes.

The following example demonstrates two levels of tiling: at the thread-block level
and at the thread level.

For thread-block level tiling, each input and output tensor is first divided
on the host side into a grid of ``(TileM, TileN)`` sub-tensors.

Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor
(``gA[((None, None), bidx)]``), which returns a thread-block local view of
a single ``(TileM, TileN)`` sub-tensor.
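For the 2048×2048 inputs in this notebook, the host code later in this section prints the block-tiled layout `((16,256),(128,8)):((2048,1),(32768,256))`, so the grid has 128 × 8 = 1024 blocks. CuTe maps a flat index into a multi-dimensional mode colexicographically (leftmost sub-mode fastest). Block `bidx` therefore corresponds to tile `(bidx % 128, bidx // 128)`. The sketch below applies that rule in plain Python and checks that each block's starting offset is the top-left element `(16*bm, 256*bn)` of its tile in the row-major tensor. The helper names are hypothetical and are not the CuTe API.

```python
def block_tile_coord(bidx, rest_shape=(128, 8)):
    """Colexicographic split of a flat block index into (bm, bn): bm varies fastest."""
    bm = bidx % rest_shape[0]
    bn = bidx // rest_shape[0]
    return bm, bn


def block_base_offset(bidx, rest_stride=(32768, 256)):
    """Element offset of the (0, 0) entry of block bidx's (16, 256) tile."""
    bm, bn = block_tile_coord(bidx)
    return bm * rest_stride[0] + bn * rest_stride[1]


N, TILE_M, TILE_N = 2048, 16, 256
for bidx in [0, 1, 127, 128, 1023]:
    bm, bn = block_tile_coord(bidx)
    # The tile's top-left corner in the row-major (2048, 2048):(2048, 1) tensor
    assert block_base_offset(bidx) == (TILE_M * bm) * N + TILE_N * bn
    print(bidx, (bm, bn), block_base_offset(bidx))
```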

For thread level tiling, we compose the sub-tensor (which maps from logical coordinates
to physical addresses) with the TV layout (which maps from thread & value indices to
logical coordinates). This gives us a tiled sub-tensor that maps from thread & value
indices directly to physical addresses.

We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)
to get a thread-local view of the data each thread accesses. Note that the thread index
is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode.

In [9]:
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    tv_layout: cute.Layout
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    #--------------------------------
    # slice for thread-block level view
    #--------------------------------
    blk_coord = ((None, None), bidx)

    # logical coord -> address
    blkA = gA[blk_coord]  # (TileM, TileN) -> physical address
    blkB = gB[blk_coord]  # (TileM, TileN) -> physical address
    blkC = gC[blk_coord]  # (TileM, TileN) -> physical address

    #--------------------------------
    # compose for thread-index & value-index to physical mapping
    #--------------------------------
    # blkA:      (TileM, TileN) -> physical address
    # tv_layout: (tid, vid)     -> (TileM, TileN)
    # tidfrgA = blkA o tv_layout
    # tidfrgA:   (tid, vid) -> physical address
    tidfrgA = cute.composition(blkA, tv_layout)
    tidfrgB = cute.composition(blkB, tv_layout)
    tidfrgC = cute.composition(blkC, tv_layout)

    print(f"Composed with TV layout:")
    print(f"  tidfrgA: {tidfrgA.type}")

    #--------------------------------
    # slice for thread-level view
    #--------------------------------
    # `None` represents a slice of the entire per-thread data
    thr_coord = (tidx, None)

    # slice for threads: vid -> address
    thrA = tidfrgA[thr_coord]  # (V) -> physical address
    thrB = tidfrgB[thr_coord]  # (V) -> physical address
    thrC = tidfrgC[thr_coord]  # (V) -> physical address

    thrC[None] = thrA.load() + thrB.load()

Let's take a closer look at the layout of the zipped-divided input tensor `gA`:

```
Tiled to Thread Block:

    ((16,256),(128,8))  : ((2048,1),(32768,256))
     ~~~~~~~~  ~~~~~~      ~~~~~~~~
        |        |            |
        |        |            |
        |        `------------------------> Number of Thread Blocks
        |                     |
        |                     |
        `--------------------'
                  |
                  V
             Thread Block
                 Tile

Sliced to Thread-Block local sub-tensor (a (16, 256) tile):  gA[((None, None), bidx)]

    (16,256)   :  (2048,1)
     ~~~~~~        ~~~~~~
        |             |        Tiled/Composed with TV Layout
        |             |    
        |             |    o   ((32,4),(8,4)):((128,4),(16,1))
        V             V         
~~~~~~~~~~~~~~~     ~~~~~~~~~~~~~~~~~~~ 
((32,4), (8,4))  :  ((8,8192),(1,2048))
    |      |
    |      `--------> per thread fragment
    |
Thread Block
  Shape

Sliced to Thread local sub-tensor (a (4,8) tile):  tidfrgA[(tidx, None)]

```
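The composed strides can be reproduced by hand. The TV layout's codomain is a linear index into the `(16,256)` tile, which is interpreted colexicographically (`m = idx % 16`, `n = idx // 16`). The block tile `(16,256):(2048,1)` then maps `(m, n)` to an element offset. The sketch below evaluates `blkA o tv_layout` in plain Python; it is not the CuTe API, and the helpers are hypothetical. It recovers the strides `((8,8192),(1,2048))` that the kernel prints for `tidfrgA`, and checks that the composed function is exactly that strided layout.

```python
TILE_M, TILE_N = 16, 256
TV_SHAPE = ((32, 4), (8, 4))
TV_STRIDE = ((128, 4), (16, 1))      # tv_layout: (tid, vid) -> linear index in the tile
BLK_STRIDE = (2048, 1)               # blkA: (m, n) -> element offset


def tv_index(t0, t1, v0, v1):
    """Evaluate tv_layout at a fully specified coordinate ((t0, t1), (v0, v1))."""
    (s_t0, s_t1), (s_v0, s_v1) = TV_STRIDE
    return t0 * s_t0 + t1 * s_t1 + v0 * s_v0 + v1 * s_v1


def blk_offset(idx):
    """Evaluate blkA at a linear tile index, split colexicographically into (m, n)."""
    m, n = idx % TILE_M, idx // TILE_M
    return m * BLK_STRIDE[0] + n * BLK_STRIDE[1]


def composed(t0, t1, v0, v1):
    return blk_offset(tv_index(t0, t1, v0, v1))


# A unit step in each sub-mode gives that sub-mode's composed stride.
strides = (
    (composed(1, 0, 0, 0), composed(0, 1, 0, 0)),
    (composed(0, 0, 1, 0), composed(0, 0, 0, 1)),
)
print(strides)  # ((8, 8192), (1, 2048))

# The composition is exactly the strided layout ((32,4),(8,4)):((8,8192),(1,2048)).
t_shape, v_shape = TV_SHAPE
ok = all(
    composed(t0, t1, v0, v1) == t0 * 8 + t1 * 8192 + v0 * 1 + v1 * 2048
    for t0 in range(t_shape[0]) for t1 in range(t_shape[1])
    for v0 in range(v_shape[0]) for v1 in range(v_shape[1])
)
print(ok)  # True
```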

The host code below constructs the TV layout. It combines a thread layout of ``(4,32):(32,1)``
(each warp of 32 consecutive threads covers contiguous elements along a row, and the 4 warps
cover different rows) with a value layout of ``(4,8):(8,1)`` (each thread accesses
8 contiguous elements per row across 4 contiguous rows).
The result is the TV layout shown in the figure above.
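Under the reading of the two layouts given above, thread `t` owns rows `4*(t // 32)` to `4*(t // 32) + 3` and columns `8*(t % 32)` to `8*(t % 32) + 7` of the `(16, 256)` tile. The sketch below enumerates that assignment for all 128 threads × 32 values and checks two things. First, it covers every tile element exactly once. Second, it agrees with the printed TV layout `((32,4),(8,4)):((128,4),(16,1))` read colexicographically over the tile. The helpers are hypothetical and are not `cute.make_layout_tv` itself.

```python
TILE_M, TILE_N = 16, 256
NUM_THREADS, VALS_PER_THREAD = 128, 32   # sizes of (4,32) and (4,8)


def tv_to_mn(tid, vid):
    """(thread, value) -> (m, n) for thr (4,32):(32,1) and val (4,8):(8,1), blocked."""
    warp_row, lane = tid // 32, tid % 32     # invert thr_layout (row-major over 4x32)
    val_row, val_col = vid // 8, vid % 8     # invert val_layout (row-major over 4x8)
    return 4 * warp_row + val_row, 8 * lane + val_col


def tv_layout_index(tid, vid):
    """Printed TV layout ((32,4),(8,4)):((128,4),(16,1)), flat tid/vid split colexicographically."""
    t0, t1 = tid % 32, tid // 32
    v0, v1 = vid % 8, vid // 8
    return t0 * 128 + t1 * 4 + v0 * 16 + v1 * 1


covered = set()
for tid in range(NUM_THREADS):
    for vid in range(VALS_PER_THREAD):
        m, n = tv_to_mn(tid, vid)
        assert tv_layout_index(tid, vid) == m + TILE_M * n   # colexicographic tile index
        covered.add((m, n))

# 4096 (tid, vid) pairs landing on 4096 distinct elements means each is owned exactly once.
print(len(covered) == TILE_M * TILE_N)  # True
```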

In [10]:
@cute.jit
def elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    # mA layout: (M, N):(N, 1)
    # The TV layout maps thread & value indices onto a (16, 256) logical tile:
    #  - consecutive thread indices map along mode-1 because the input is
    #    contiguous on mode-1; this gives coalesced loads and stores
    #  - each thread accesses 8 contiguous elements per row, across 4 rows
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 8), stride=(8, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    print(f"Tiler: {tiler_mn}")
    print(f"TV Layout: {tv_layout}")

    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM, TileN), (RestM, RestN))

    print(f"Tiled Input Tensors:")
    print(f"  gA: {gA.type}")
    print(f"  gB: {gB.type}")
    print(f"  gC: {gC.type}")

    # Launch the kernel asynchronously
    # Async token(s) can also be specified as dependencies
    elementwise_add_kernel(
        gA, gB, gC, tv_layout
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )

a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

a_ = from_dlpack(a, assumed_align=16)
b_ = from_dlpack(b, assumed_align=16)
c_ = from_dlpack(c, assumed_align=16)

elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)
elementwise_add_(a_, b_, c_)

# verify correctness
torch.testing.assert_close(c, a + b)

Tiler: (16, 256)
TV Layout: ((32,4),(8,4)):((128,4),(16,1))
Tiled Input Tensors:
  gA: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gB: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gC: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
Composed with TV layout:
  tidfrgA: !cute.memref<f16, gmem, align<16>, "((32,4),(8,4)):((8,8192),(1,2048))">


In [11]:
benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)

Average execution time: 0.0222 ms
Throughput: 1133.58 GB/s


### Using Lambda Function

Because CuTe DSL is built on top of Python, we can use ordinary Python metaprogramming to generate flexible kernels.
For example, we can write a kernel template that takes a custom binary operation and generates a kernel for that operation.


```python
@cute.jit
def elementwise_op(
    op: cutlass.Constexpr,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor
):
    ...

```
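The pattern resembles an ordinary Python closure factory. The operation is fixed when the function is built, and the generated function calls only that operation. The sketch below illustrates this on CPU lists; the name `make_elementwise` is hypothetical. In CuTe DSL, as the comment in the kernel below notes, `op` must be a `cutlass.Constexpr` so that its code is generated at compile time.

```python
import operator


def make_elementwise(op):
    """Return a function that applies the binary `op` to two equally shaped 2D lists."""
    def apply(A, B):
        return [[op(a, b) for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]
    return apply


def mul_relu_scalar(a, b):
    """Scalar analogue of the multiply-then-ReLU operation used later in this section."""
    return max(a * b, 0)


A = [[1, -2], [3, 4]]
B = [[5, 6], [-7, 8]]
add = make_elementwise(operator.add)
mul = make_elementwise(operator.mul)
fused = make_elementwise(mul_relu_scalar)
print(add(A, B))    # [[6, 4], [-4, 12]]
print(mul(A, B))    # [[5, -12], [-21, 32]]
print(fused(A, B))  # [[5, 0], [0, 32]]
```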

In [12]:
@cute.kernel
def elementwise_apply_kernel(
    op: cutlass.Constexpr,    # lambda function must be const expr to generate code at compile time
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    tv_layout: cute.Layout
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    blk_coord = ((None, None), bidx)

    # logical coord -> address
    blkA = gA[blk_coord]  # (TileM, TileN) -> physical address
    blkB = gB[blk_coord]  # (TileM, TileN) -> physical address
    blkC = gC[blk_coord]  # (TileM, TileN) -> physical address

    tidfrgA = cute.composition(blkA, tv_layout)
    tidfrgB = cute.composition(blkB, tv_layout)
    tidfrgC = cute.composition(blkC, tv_layout)

    print(f"Composed with TV layout:")
    print(f"  tidfrgA: {tidfrgA.type}")

    thr_coord = (tidx, None)

    # slice for threads: vid -> address
    thrA = tidfrgA[thr_coord]  # (V) -> physical address
    thrB = tidfrgB[thr_coord]  # (V) -> physical address
    thrC = tidfrgC[thr_coord]  # (V) -> physical address

    #--------------------------------
    # apply custom operation
    #--------------------------------
    thrC[None] = op(thrA.load(), thrB.load())


@cute.jit
def elementwise_op(
    op: cutlass.Constexpr,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    # mA layout: (M, N):(N, 1)
    # The TV layout maps thread & value indices onto a (16, 256) logical tile:
    #  - consecutive thread indices map along mode-1 because the input is
    #    contiguous on mode-1; this gives coalesced loads and stores
    #  - each thread accesses 8 contiguous elements per row, across 4 rows
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 8), stride=(8, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    print(f"Tiler: {tiler_mn}")
    print(f"TV Layout: {tv_layout}")

    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM, TileN), (RestM, RestN))

    print(f"Tiled Input Tensors:")
    print(f"  gA: {gA.type}")
    print(f"  gB: {gB.type}")
    print(f"  gC: {gC.type}")

    # Launch the kernel asynchronously
    # Async token(s) can also be specified as dependencies
    elementwise_apply_kernel(
        op, gA, gB, gC, tv_layout
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )

a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

a_ = from_dlpack(a, assumed_align=16)
b_ = from_dlpack(b, assumed_align=16)
c_ = from_dlpack(c, assumed_align=16)

from operator import mul

elementwise_op(mul, a_, b_, c_)

# verify correctness
torch.testing.assert_close(c, mul(a, b))

Tiler: (16, 256)
TV Layout: ((32,4),(8,4)):((128,4),(16,1))
Tiled Input Tensors:
  gA: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gB: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gC: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
Composed with TV layout:
  tidfrgA: !cute.memref<f16, gmem, align<16>, "((32,4),(8,4)):((8,8192),(1,2048))">


Custom operators can be more complex. For example, here's a function that performs
multiplication followed by ReLU:

In [13]:
def mul_relu(a, b):
    tmp = a * b
    return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))


# The custom op uses CuTe DSL functions (cute.where, cute.full_like), so we write a separate PyTorch reference for verification
def mul_relu_ref(a, b):
    tmp = a * b
    return torch.relu(tmp)


elementwise_op(mul_relu, a_, b_, c_)

# verify correctness
torch.testing.assert_close(c, mul_relu_ref(a, b))

Tiler: (16, 256)
TV Layout: ((32,4),(8,4)):((128,4),(16,1))
Tiled Input Tensors:
  gA: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gB: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
  gC: !cute.memref<f16, gmem, align<16>, "((16,256),(128,8)):((2048,1),(32768,256))">
Composed with TV layout:
  tidfrgA: !cute.memref<f16, gmem, align<16>, "((32,4),(8,4)):((8,8192),(1,2048))">
