In [2]:
uv pip install tilus torch "cuda-python<13"

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mtilus==0.1.1                                                                  [0m[2K[37m⠙[0m [2mtorch==2.8.0+cu129                                                            [0m[2K[37m⠙[0m [2mcuda-python==12.9.2                                                           [0m[2K[37m⠹[0m [2mcuda-python==12.9.2                                                           [0m[2K[37m⠹[0m [2mhidet==0.6.1                                                                  [0m[2K[37m⠹[0m [2mtabulate==0

# Goals
 Tiles can be thought of as a somewhat high-level introduction to CUDA because it abstracts many things away from you. In fact, it can be even more effective than Triton for learning. In turn, it can allow us to become more familiar with the high-level objectives of CUDA without many of the common problems you may encounter with it.

# Note
On top, it has a simple installation process and can be quickly walked through in less than a few hours. 

TILUS does allow for explicit control of the memory hierarchy that Triton does not allow, which can make it more precise. 

It inherits principles from TVM, even though that project is considered relatively legacy for now. 

You can access CUDA-specific optimizations, which can be harder to do or less common in Triton. 

Arbitrary low-precision types are especially useful for future developments. For example, FP4 was just launched in September 2025 for CUDA 13.

Overall, it can also be useful for language model inference practitioners who want a deeper dive into various elements used in the day-to-day operations for modeling code and kernels.



# Installation
In this specific modal environment, you'll have to use cuda-python less than 13. 

In [2]:
import torch
import tilus


# define the kernel by subclassing `tilus.Script`
class MyKernel(tilus.Script):
    def __call__(self):
        # the configuration settings
        self.attrs.blocks = 1  # one thread block
        self.attrs.warps = 1  # one warp per thread block

        self.printf("Hello, World!")


# instantiate the kernel
kernel = MyKernel()

# launch the kernel on GPU
kernel()
torch.cuda.synchronize()
# sync causes the kernel to wait for all
# kernels in streams to complete
# it's important, because it basically says:
# "wait for all kernels to finish before proceeding"
# so all subsequent code will be executed only after the kernel has finished
# torch.cuda.synchronize()
# it's sort of like await.
# normally we could just dispatch the cuda execution,
# and let it run free like a kid.

  """Check whether the multi-function fa covers the multi-function fb.
[Building] my_kernel: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.80s/it][Building] my_kernel: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.80s/it]


Hello, World!

In [3]:
import torch
import tilus
from tilus import float32, int32
from tilus.utils import cdiv


class AddOneKernel(tilus.Script):
    def __init__(self, block_n, warps):
        super().__init__()
        self.block_n: int = block_n
        self.warps: int = warps

    def __call__(self, n: int32, a_ptr: ~float32, b_ptr: ~float32):
        self.attrs.blocks = cdiv(n, self.block_n)  # define the number of thread blocks
        self.attrs.warps = self.warps  # define the number of warps per block

        # get the offset for the current block
        offset = self.blockIdx.x * self.block_n

        # create two global tensors for input and output, given their pointers
        ga = self.global_view(a_ptr, shape=[n], dtype=float32)
        gb = self.global_view(b_ptr, shape=[n], dtype=float32)

        # load the inputs from global memory into a register tensor
        a = self.load_global(ga, offsets=[offset], shape=[self.block_n])

        # perform the computation: add 1 to each element in the register tensor
        b = a + 1.0

        # store the result back to global memory
        self.store_global(gb, b, offsets=[offset])


def main():
    # define the kernel
    kernel = AddOneKernel(block_n=128, warps=4)

    # create input and output tensors
    n = 16
    a = torch.arange(n, dtype=torch.float32)
    b = torch.empty_like(a)

    # launch the kernel
    kernel(n, a, b)

    print(a)
    print(b)


main()

[Building] add_one_kernel-d1: 100%|███████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.16s/it][Building] add_one_kernel-d1: 100%|███████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.17s/it]

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15.])
tensor([7.7781e+31, 2.0706e-19, 1.8888e+31, 2.7410e+17, 1.3563e-19, 3.9168e-02,
        2.0513e+17, 1.3563e-19, 3.9155e-02, 4.7429e+30, 2.0108e+20, 1.1257e+24,
        3.3844e-12, 7.9309e+34, 6.0022e+31, 4.2964e+24])





# Basic Language Constructs
What can we use in `Tilus`?
The `__init__` is responsible for doing compile time setup, so it will 'pre-compute' over the hyper params, you can use it, or you can just record it. In our example, we set numbers such as number of blocks, number of warps. When we call the init method, those will be set (and the standard python init follows.)

`__call__` method works for defining actual code that we execute. We need to say: How many 'thread blocks' to launch, so it represents the # of blocks in each dimension, since we are using the 3d block model. It's important to understand this prior to working a lot in tilus.

Warps per thread block needs to say the # of warps per thread block. So it is the compilation time constant.


# JIT
 The common pattern with JIT functions is that at the first run, we will compile them and have an optimized version of them. All kernel parameters must be typed and annotated with type. 
 For JIT annotations and non-JIT annotations, the kinds of valid types can differ.

# Objectives
 In the JIT module, you can do a few main things, including on-demand compilation, where the Script class is compiled into distinct kernels for each setup when it's first invoked with those runtime parameters. It can also perform the popular task of exploring parameter space with autotune decorator, which is similar to Triton. 
  And this is all with the overarching goal of performance optimization. 

In [None]:
import math
import torch
import tilus
from tilus import float16, float32, int32
from tilus.utils import cdiv


class Matmul(tilus.Script):
    def __init__(self):
        super().__init__()
        self.block_m = 64
        self.block_n = 128
        self.block_k = 16

    def __call__(self,
            m_size: int32, n_size: int, k_size: int,
            a_ptr: ~float16, b_ptr: ~float16, c_ptr: ~float16
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),  # the x dimension size of the grid
            cdiv(n_size, self.block_n),  # the y dimension size of the grid
        ]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        # create two global tensors `ga` and `gb`
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # create a register tensor `acc` for accumulating the results.
        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # iterate over the k dimension in blocks of size `block_k`.
        for k in range(cdiv(k_size, self.block_k)):
            # calculate the offset for the current block in the k dimension
            offset_k = k * self.block_k

            # load a block of matrix A and B into register tensors `a` and `b`.
            a = self.load_global(
                ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k]
            )
            b = self.load_global(
                gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n]
            )

            # perform the dot product: acc = a @ b + acc
            self.dot(a, b, acc, out=acc)

        # after the loop, we cast the accumulated result `acc` to float16 type
        acc_f16 = self.cast(acc, dtype=float16)

        # store it back to the output matrix C.
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, acc_f16, offsets=[offset_m, offset_n])


def main():
    kernel = Matmul()

    for k_size, n_size in [(4096, 4096), (4096, 12288)]:
        for m_size in [1, 4, 8, 16]:
            a = torch.randn(m_size, k_size, dtype=torch.float16, device='cuda') / math.sqrt(k_size)
            b = torch.randn(k_size, n_size, dtype=torch.float16, device='cuda') / math.sqrt(k_size)
            c = torch.empty(m_size, n_size, dtype=torch.float16, device='cuda')

            kernel(m_size, n_size, k_size, a, b, c)
            torch.testing.assert_close(c, torch.matmul(a, b), rtol=1e-2, atol=1e-2)

main()

In [4]:
import math
import torch
import tilus
from tilus import float16, float32, int32
from tilus.utils import cdiv


class Matmul(tilus.Script):
    def __init__(self):
        super().__init__()
        self.block_m = 64
        self.block_n = 128
        self.block_k = 16

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),  # the x dimension size of the grid
            cdiv(n_size, self.block_n),  # the y dimension size of the grid
        ]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        # create two global tensors `ga` and `gb`
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # create a register tensor `acc` for accumulating the results.
        acc = self.register_tensor(dtype=float32, shape=[self.block_m, self.block_n], init=0.0)

        # iterate over the k dimension in blocks of size `block_k`.
        for k in range(cdiv(k_size, self.block_k)):
            # calculate the offset for the current block in the k dimension
            offset_k = k * self.block_k

            # load a block of matrix A and B into register tensors `a` and `b`.
            a = self.load_global(
                ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k]
            )
            b = self.load_global(
                gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n]
            )

            # perform the dot product: acc = a @ b + acc
            self.dot(a, b, acc, out=acc)

        # after the loop, we cast the accumulated result `acc` to float16 type
        acc_f16 = self.cast(acc, dtype=float16)

        # store it back to the output matrix C.
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, acc_f16, offsets=[offset_m, offset_n])


def main():
    kernel = Matmul()

    for k_size, n_size in [(4096, 4096), (4096, 12288)]:
        for m_size in [1, 4, 8, 16]:
            a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda") / math.sqrt(k_size)
            b = torch.randn(k_size, n_size, dtype=torch.float16, device="cuda") / math.sqrt(k_size)
            c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

            kernel(m_size, n_size, k_size, a, b, c)
            torch.testing.assert_close(c, torch.matmul(a, b), rtol=1e-2, atol=1e-2)


main()

[Building] matmul-4096-4096-d1: 100%|█████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.37s/it][Building] matmul-4096-4096-d1: 100%|█████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.37s/it]
[Building] matmul-12288-4096-d1: 100%|████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.24s/it][Building] matmul-12288-4096-d1: 100%|████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.24s/it]


# Example of a JIT Kernel
This example has a `AcceleratorError Cuda Error`.
Conceretely, we start by defining M, N, and K,

The next step is to create two global tensors, A and B, which are basically just the A and B defined in the matrix multiplication. we have a register tensor that holds the results temporarily to be in register for fast access and write locality. all we have to do is load incrementally in the number of steps of blocks, the matrix A and B.

Then we compute the dot product and we keep on repeating this process. 

In terms of quantization, we cast after the result for precision to float32. This can save us memory. Then we put it into the output matrix, which is not the same as the register tensor. 

unfortunately, at the time of writing, 16 september 2025, this kernel does have an illegal memory access error on a nvidia l4, which should work. we're also using the right version of cuda python, which should be less than 13, which is a strange bug. 


# Tilus Type System
 Languages like CUDA are obviously strongly typed. Python is not. It's important for Tilus to be typed in order to have some guarantee of correctness. 

 Three main types in Tilus are Scalar, Pointer, and Tensor, where Tensors simply are groups of scalars that are defined using pointers for each entry (or whatever row/col-major access exists)

# Scalars
 https://nvidia.github.io/tilus/programming-guides/type-system/scalar-types.html
 It might be interesting to look here to see all of the high- and low-precision floats and integers that exist. 


# Pointers
The syntax is `~<dtype>` where `dtype` is one of the scalar types. Easy hierarchy to understand!
 You can point to a scalar, a pointer, or to a void pointer, which means you're pointing to a data type you don't know yet, which is quote: a generic pointer type. 

 # Register Tensor
 This means the tensor is stored in the register, the fast access system. 
 You can define it as such:
 ```python
 self.register_tensor(dtype=float32, shape=[32, 64])
```

It's optional to specify the layout of the register tensor. The shape is simply a tuple of integers specifying the size of each dimension. (Array? The example above is directly from the documentation. )

Seek the dedicated explanation for register layout to explain how the tensor elements are distributed among each thread in the thread block. (The 3D element)

# Shared Tensor
 This is stored in the shared memory of the GPU thread block, everybody in shared memory needs to be explicitly allocated and deallocated and before we free it we need to make sure nothing is waiting for it to use in its operations. We cannot directly operate on shared tensors. We need to load shared tensors into the appropriate registers, make the computation, and send it back to shared memory. The reason behind this is simple: speed. In one line, each thread in the thread block can access the shared tensor. Outside of the block, that is not true.  
 Example:
```python
self.shared_tensor(dtype=float32, shape=[32, 64])
```

 Again, the shared tensor also has an optional layout parameter to specify the layout of the tensor. A shared layout defines how tensors are stored. The shared layout is a mapping utility from the multi-dimensional shape of the tensor to a linear memory address in the shared memory, for example, row-wise or column-wise, row-major or column-major layout.

 The shared tensor is shared by all threads in the thread block. Everybody can access them; however, we want to optimize the access patterns and data locality, which is another recurring principle. 

# Global Tensor

```python
self.global_tensor(dtype=float32, shape=[32, 64])
```
 The global tensor can be used as a tensor shared by every single thread block in a given kernel. Then the global tensor can be alive for the lifetime of the entire kernel.
 Just like shared tensors, we don't provide direct memory access. They will be loaded into registered tensors. 

 In a global tensor, you can use layout or strides as optional parameters. Normally, we'll assume row-major, which is a compact form. Otherwise, we can use strides to define the strides of the tensor in each dimension, which is the number of elements to skip in each dimension to get to the next entry in that dimension. This will affect how it's laid out in memory. If your layout parameter is provided, it will have some custom mapping from the 3D indices to the linear memory address. This is used rarely in practice.

 # Instructions
 https://nvidia.github.io/tilus/programming-guides/instructions.html

Interesting notes:
 these have basically all of the mathematical methods in elementwise and transforms. We also have load, store, and asynchronous copy instructions. 

 # Instruction Overview
 Here is an overview of some of the most interesting ones. 
 

# Autotuning
Autotuning is just HPO! Each configuration is called a schedule, and run the different schedules, benchmarking each schedule's performance.



In [3]:
from tilus.ir.layout import spatial, local, visualize_layout
print(visualize_layout(local(3, 4)))

  """Check whether the multi-function fa covers the multi-function fb.


RegisterLayout(shape=[3, 4], mode_shape=[3, 4], spatial_modes=[], local_modes=[0, 1])
┌──────┬──────┬───────┬───────┐
│ 0: 0 │ 0: 1 │ 0: 2  │ 0: 3  │
├──────┼──────┼───────┼───────┤
│ 0: 4 │ 0: 5 │ 0: 6  │ 0: 7  │
├──────┼──────┼───────┼───────┤
│ 0: 8 │ 0: 9 │ 0: 10 │ 0: 11 │
└──────┴──────┴───────┴───────┘



Register Layouts (from scratch, with diagrams)
==============================================

0) Big picture
--------------
A **register layout** tells you, for every element of a logical tensor owned by a thread block, **which thread(s)** hold that element and **at what local register index within that thread** it lives. This is fundamentally different from global/shared-memory layouts:

- **Global / shared layout**: “Given a tensor index `(i, j, …)`, what is the byte address in global or shared memory?”
- **Register layout**: “Given a tensor index `(i, j, …)`, which **thread id(s)** store it, and what is the **local register slot** inside each such thread?”

Concretely, a register layout is a mapping:

    (thread_id, local_id)  <->  logical_index  (i.e., (i, j, ..., in d dimensions))

The mapping is many-to-one in general, because multiple threads can legally hold replicas of the same logical element.


1) The two essential kinds of modes
-----------------------------------
We factor the tensor’s shape into **modes** (sub-dimensions). Each mode is then assigned to one of two disjoint categories:

- **Spatial modes** → decide **which thread** gets the element. They index over parallel workers.
- **Local modes**   → decide **where inside the thread’s local register array** the element sits.

This split is the heart of the system. One mode cannot be both spatial and local. Violating that creates contradictions and is disallowed.

We track four attributes:
- `shape`: the overall tensor shape (must match the tensor this layout is for).
- `mode_shape`: the sizes of all sub-dimensions after optional splits. Modes of size `1` are pruned.
- `spatial_modes`: an ordered list of mode indices used to compute `thread_id`.
- `local_modes`:   an ordered list of mode indices used to compute `local_id`.

Order matters in both lists; it defines the linearization order (row-major over the listed modes).


2) From raw tensor shape to `mode_shape`
----------------------------------------
You may optionally **split** each tensor dimension into multiple **modes**. Example:

- Tensor shape `[3, 4]`.
  - Keep the first dimension as is: mode size `3`.
  - Split the second dimension `4` into `2 × 2`.
  - Resulting `mode_shape = [3, 2, 2]`.

Another example:

- Tensor shape `[12, 1, 6]`.
  - Split `12` into `3 × 4`.
  - Keep `1` as is (it will be pruned later).
  - Split `6` into `2 × 3`.
  - Raw mode shape `[3, 4, 1, 2, 3]` → prune ones → `mode_shape = [3, 4, 2, 3]`.

Intuition: you are factorizing a mixed-radix indexer. This gives you dials you can assign to “which thread?” versus “which slot inside the thread?”


3) How a layout answers the central question
--------------------------------------------
**Given a logical tensor index `(i, j, …)`**, the layout determines:
- the corresponding **mode indices** (one per entry in `mode_shape`);
- the subset of those mode indices that are **spatial**, in the order specified by `spatial_modes`;
- the subset that are **local**, in the order specified by `local_modes`;
- a row-major linearization of the spatial subset → `thread_id`;
- a row-major linearization of the local subset → `local_id`.

Pseudocode (normal, non-replicated case):

```

# Inputs:

# shape           ... tensor shape (for bounds)

# mode\_shape      ... e.g., \[m0, m1, m2, ...]

# spatial\_modes   ... e.g., \[s0, s1, ...]   (indices into mode\_shape)

# local\_modes     ... e.g., \[l0, l1, ...]   (indices into mode\_shape)

# idx             ... logical tensor index (multidim), e.g., (i, j, ...)

# Step A: map idx -> mode\_index using mixed-radix decomposition:

# Flatten idx to a single linear p in row-major over 'shape', then unflatten over 'mode\_shape'.

mode\_index = unflatten\_over\_mode\_shape(flatten\_over\_shape(idx, shape), mode\_shape)

# mode\_index has length len(mode\_shape)

# Step B: pick out spatial and local sub-tuples in the requested order:

spatial\_index = \[ mode\_index\[k] for k in spatial\_modes ]
local\_index   = \[ mode\_index\[k] for k in local\_modes ]

# Step C: row-major linearization helpers:

def row\_major\_linear(index\_tuple, shape\_tuple):
\# index\_tuple and shape\_tuple have same length
acc = 0
for t in range(len(index\_tuple)):
acc = acc \* shape\_tuple\[t] + index\_tuple\[t]
return acc

spatial\_shape = \[ mode\_shape\[k] for k in spatial\_modes ]
local\_shape   = \[ mode\_shape\[k] for k in local\_modes ]

thread\_id = row\_major\_linear(spatial\_index, spatial\_shape)
local\_id  = row\_major\_linear(local\_index,   local\_shape)

```

This is exactly what your worked example in §8.3.5.1 does with `mode_shape=[2,2,3,2]`, `spatial_modes=[0,2]`, and `local_modes=[3,1]`:
- `spatial_index = [i//2, j//2]`, shaped `[2,3]` → `thread_id = (i//2)*3 + (j//2)`.
- `local_index   = [j%2,  i%2]`, shaped `[2,2]` → `local_id  = (j%2)*2 + (i%2)`.

ASCII diagram of the idea (each box = one logical element):
```

+----------------------------------+
\| logical tensor (shape H x W)     |
\|   split into modes (mode\_shape)  |
+-------------------+--------------+
|
v
\[split into modes]
mode indices ---> pick & order ---> \[spatial modes] --row-major--> thread\_id
-> pick & order ---> \[local modes]   --row-major--> local\_id

```


4) Two trivial layouts to build intuition
-----------------------------------------
4.1) **Local-only** layout: everything lives in **one thread**; that one thread stores all elements in row-major order:

```

layout = local(3, 4)   # shape = (3,4); spatial\_modes=\[], local\_modes=\[0,1]

Grid diagram (each cell shows "thread\_id : local\_id"):

┌──────┬──────┬───────┬───────┐
│ 0: 0 │ 0: 1 │ 0: 2  │ 0: 3  │
├──────┼──────┼───────┼───────┤
│ 0: 4 │ 0: 5 │ 0: 6  │ 0: 7  │
├──────┼──────┼───────┼───────┤
│ 0: 8 │ 0: 9 │ 0: 10 │ 0: 11 │
└──────┴──────┴───────┴───────┘

```

4.2) **Purely spatial** layout: each element goes to a distinct thread; each thread stores exactly one element:

```

layout = spatial(3, 2)   # shape = (3,2); spatial\_modes=\[0,1], local\_modes=\[]

┌──────┬──────┐
│ 0: 0 │ 1: 0 │
├──────┼──────┤
│ 2: 0 │ 3: 0 │
├──────┼──────┤
│ 4: 0 │ 5: 0 │
└──────┴──────┘

```

These are opposite ends of the spectrum. Real kernels interleave both kinds of modes to balance occupancy, register pressure, coalescing, and math/tensor-core feed rates.


5) Composition: replacing each element with a tile
--------------------------------------------------
**Composition** says: “take each element of an outer layout, and **replace it** with a whole tensor (an inner layout).” Dimensions multiply; mode lists concatenate with adjusted indices.

Composition is **associative** (grouping does not matter) but **not commutative** (order matters).

5.1) Example A — `local(3,4).spatial(2,3)`:

Interpretation: start from one-thread local storage of a 3×4 tile; then fan each element out spatially over `(2,3)` threads.

Result shape: `(3*2, 4*3) = (6,12)`.

```

RegisterLayout(shape=\[6, 12], mode\_shape=\[3, 2, 4, 3],
spatial\_modes=\[1, 3], local\_modes=\[0, 2])

\[Partial grid excerpt; each cell shows "thread\_id : local\_id"]

Row 0:  0:0  1:0  2:0   0:1  1:1  2:1   0:2  1:2  2:2   0:3  1:3  2:3
Row 1:  3:0  4:0  5:0   3:1  4:1  5:1   3:2  4:2  5:2   3:3  4:3  5:3
...

```

Here, each original “local element” became a 2×3 block spread across 6 threads, but the local tiling `[3,4]` is still present as part of `local_id`.

5.2) Example B — `spatial(2,3).local(3,4)`:

Inverse order: first spread over 6 threads, then inside each thread store a 3×4 tile. The global shape is the same `(6,12)`, but the mapping is different:

```

RegisterLayout(shape=\[6, 12], mode\_shape=\[2, 3, 3, 4],
spatial\_modes=\[0, 2], local\_modes=\[1, 3])

Top-left 3×4 is entirely in thread 0 with local\_ids 0..11,
next 3×4 is in thread 1, etc. Different distribution than Example A.

```

ASCII contrast:

```

local(3,4).spatial(2,3):            spatial(2,3).local(3,4):
\[local first, then scatter]         \[scatter first, then local tiles]

```

Composition is the workhorse for building complex MMA-friendly layouts out of small, clear pieces.


6) Tensor Core (PTX MMA) operand layouts
----------------------------------------
PTX manuals document register layouts for MMA operands. Those are precisely **register** layouts in this framework.

Example: operand **C** for `mma.sync.aligned.m16n8k8.f16,f16,f16,f16`.

One concise construction:

```

layout = repeat(2, 1).spatial(8, 4).repeat(1, 2)

RegisterLayout(shape=\[16, 8],
mode\_shape=\[2, 8, 4, 2],
spatial\_modes=\[1, 2],
local\_modes=\[0, 3])

```

Explanation:

- `spatial(8,4)` selects the 32 participating threads (8×4) and determines **which** thread holds each C element.
- The two `repeat(...)` wrappers add **local replication modes** of sizes 2 and 2 respectively, yielding local register indices `[0..3]` per thread for the fragment.
- The final mode order (`spatial_modes=[1,2]`, `local_modes=[0,3]`) matches the published mapping.

Visual pattern (each cell shows “t : r” for thread and local register index):

```

Row 0..7:   threads 0..31 with local regs 0..1
Row 8..15:  threads 0..31 with local regs 2..3

```

You can construct the standard A/B/C operand tilings the same way by composing `spatial`, `local`, and `repeat` blocks in the order the PTX figures imply.


7) Allowing multiple threads to hold the same element
-----------------------------------------------------
Some algorithms need **replication**: multiple threads must read/update the same logical element. The layout explicitly supports this by letting **spatial modes** include a **replication mode**, represented as a **negative size**: `-R` means “replicate R times”.

Example:

```

spatial(3,4):
┌──────┬──────┬───────┬───────┐
│ 0:0  │ 1:0  │ 2:0   │ 3:0   │
├──────┼──────┼───────┼───────┤
│ 4:0  │ 5:0  │ 6:0   │ 7:0   │
├──────┼──────┼───────┼───────┤
│ 8:0  │ 9:0  │ 10:0  │ 11:0  │
└──────┴──────┴───────┴───────┘

```

Reducing over the first dimension (collapsing rows) produces **replication**:

```

reduce(spatial(3,4), dims=\[0])  ⇒  shape=\[4], spatial\_modes=\[-3, 0]

┌────────────────┬────────────────┬─────────────────┬─────────────────┐
│ \[0,4,8]:0      │ \[1,5,9]:0      │ \[2,6,10]:0      │ \[3,7,11]:0      │
└────────────────┴────────────────┴─────────────────┴─────────────────┘

Meaning: the logical column-0 element is held by threads {0,4,8} at local\_id 0, etc.

```

In mapping terms, a replicated spatial mode yields **multiple** `thread_id` results for one logical index. Local indices stay the same; you just return a set of `(thread_id, local_id)` pairs.


8) Column-major vs row-major local/spatial orders
-------------------------------------------------
Changing the **order** of modes in `local_modes` reorders registers inside each thread. Changing the order in `spatial_modes` permutes which thread owns which element.

Row-major local:

```

local(2,3): local\_modes=\[0,1]
┌──────┬──────┬──────┐
│ 0:0  │ 0:1  │ 0:2  │
├──────┼──────┼──────┤
│ 0:3  │ 0:4  │ 0:5  │
└──────┴──────┴──────┘

```

Column-major local:

```

column\_local(2,3): local\_modes=\[1,0]
┌──────┬──────┬──────┐
│ 0:0  │ 0:2  │ 0:4  │
├──────┼──────┼──────┤
│ 0:1  │ 0:3  │ 0:5  │
└──────┴──────┴──────┘

```

Row-major spatial:

```

spatial(2,3): spatial\_modes=\[0,1]
┌──────┬──────┬──────┐
│ 0:0  │ 1:0  │ 2:0  │
├──────┼──────┼──────┤
│ 3:0  │ 4:0  │ 5:0  │
└──────┴──────┴──────┘

```

Column-major spatial:

```

column\_spatial(2,3): spatial\_modes=\[1,0]
┌──────┬──────┬──────┐
│ 0:0  │ 2:0  │ 4:0  │
├──────┼──────┼──────┤
│ 1:0  │ 3:0  │ 5:0  │
└──────┴──────┴──────┘

```

These differences matter for vectorization, coalescing, and matching hardware operand swizzles.


9) The mapping process: full algorithm with replication
-------------------------------------------------------
We extend the earlier pseudocode to handle **replicated spatial modes**. A replicated mode has size `-R` (negative). Treat that as “R choices of a replication lane” that **do not** change the logical index, only which thread copy you pick.

```python

def apply\_layout(shape, mode\_shape, spatial\_modes, local\_modes, idx):
\# 1) idx -> mode\_index via mixed radix decomposition:
p = flatten\_over\_shape(idx, shape)              # linear in row-major over 'shape'
mode\_index = unflatten\_over\_mode\_shape(p, mode\_shape)


# 2) Gather local indices and shapes in order:
local_idx   = [mode_index[k] for k in local_modes]
local_shape = [abs(mode_shape[k]) for k in local_modes]
local_id    = row_major_linear(local_idx, local_shape)

# 3) Gather spatial indices. Separate normal vs replicated:
spatial_idx_vals = []
spatial_shape    = []
rep_factors      = []   # collect replication sizes

for k in spatial_modes:
    sz = mode_shape[k]
    if sz > 0:
        spatial_idx_vals.append(mode_index[k])
        spatial_shape.append(sz)
    else:
        rep_factors.append(-sz)  # replication of size R = -sz

base_thread_id = row_major_linear(spatial_idx_vals, spatial_shape)

# 4) If replicated, return R copies with offsets over a replication lane.
# The actual placement of replication lanes in thread_id space is implementation-defined.
# A simple scheme: treat replication lanes as a leading product factor.
if not rep_factors:
    return [(base_thread_id, local_id)]
else:
    R = 1
    for r in rep_factors: R *= r
    # produce R copies; lane_id in [0..R-1]
    return [ (lane_id * total_threads_per_nonrep + base_thread_id, local_id)
             for lane_id in range(R) ]

# Note: 'total\_threads\_per\_nonrep' depends on how the runtime packs replication lanes into thread\_id space.

# Frameworks may interleave or block them; the semantics are "R identical copies exist", not the exact numbering.
```

In plain language: compute the usual `thread_id` from normal spatial modes; then fan that out across replication lanes. Each fan-out corresponds to “the same logical element appears in multiple threads.”


10) Operations provided by the layout module
--------------------------------------------
10.1) **Creation**
- `spatial(*shape[, ranks])` — make a purely spatial layout (every element is owned by some thread; no local tiling).
- `local(*shape[, ranks])` — make a purely local layout (single thread; everything in registers).
- `column_spatial(*shape)` — same as `spatial` but mode order is column-major (i.e., reversed for 2D).
- `column_local(*shape)` — same as `local` but local mode order is column-major.
- `auto_local_spatial(num_threads, shape)` — produce a `local(...).spatial(...)` that uses the given number of threads for the given tensor `shape`.

10.2) **Transformation** (do **not** change “threads per element” and “elements per thread”; they retile the same resources)
- `squeeze(layout, dims)` — remove size-1 dimensions in `shape` and the corresponding modes.
- `unsqueeze(layout, dims)` — insert size-1 dimensions.
- `permute(layout, dims)` — reorder tensor dimensions (and remap modes consistently).
- `reshape(layout, shape)` — change `shape` while preserving element count; redistributes mode factors accordingly.
- `flatten(layout[, start_dim, end_dim])` — merge a range of dimensions into one.

10.3) **Composition**
- `concat(lhs, rhs)` — concatenate layouts along a dimension (grow shape; threads/elements per thread unchanged per tile).
- `compose(outer, inner)` — replacement rule: each element of `outer` becomes a whole `inner` tile; shapes multiply; spatial/local mode lists combine.

10.4) **Other**
- `divide(lhs, rhs)` — inverse of certain compositions when factorizations match (split a layout into outer/inner components).
- `reduce(layout, dims[, keepdims])` — remove dimensions by combining them; if you reduce away a spatial dimension, you typically introduce **replication** (negative spatial mode sizes).


11) Worked example mirroring §8.3.5.1
-------------------------------------
Given:
- `shape = [4, 6]`
- `mode_shape = [2, 2, 3, 2]`
- `spatial_modes = [0, 2]`
- `local_modes   = [3, 1]`

For a logical index `(i, j)`:

- Compute `mode_index = [ i//2, i%2, j//2, j%2 ]`.
- Spatial subtuple (in listed order) is `[ i//2, j//2 ]` with shape `[2, 3]`.
  → `thread_id = (i//2) * 3 + (j//2)`.
- Local subtuple   (in listed order) is `[ j%2,  i%2 ]` with shape `[2, 2]`.
  → `local_id  = (j%2) * 2 + (i%2)`.

This gives an unambiguous, fast, index-to-(thread,register) mapping suitable for kernel code generation.


12) Inverse mapping (from a location back to a logical index)
-------------------------------------------------------------
When an element is **not replicated**, the mapping is bijective from `(thread_id, local_id)` to a unique logical index.

Sketch:

```

# Given thread\_id, invert spatial linearization to get spatial\_index tuple:

spatial\_index = unlinearize\_row\_major(thread\_id, spatial\_shape)

# Given local\_id, invert local linearization:

local\_index   = unlinearize\_row\_major(local\_id,   local\_shape)

# Now fill a full mode\_index\[] of length len(mode\_shape):

for t, k in enumerate(spatial\_modes): mode\_index\[k] = spatial\_index\[t]
for t, k in enumerate(local\_modes):   mode\_index\[k] = local\_index\[t]

# Finally, re-linearize over mode\_shape, then unflatten over the original 'shape':

p = linearize\_over\_mode\_shape(mode\_index, mode\_shape)
idx = unflatten\_over\_shape(p, shape)

```

With replication present, multiple `(thread_id, local_id)` pairs map to the **same** logical index; the inverse returns that same `idx` regardless of which replica you use.


13) Practical guidance and common pitfalls
------------------------------------------
- Do not assign the same mode to both `spatial_modes` and `local_modes`. This breaks the definition; the implementation forbids it.
- Pay attention to **mode order**. Changing `[0,1]` to `[1,0]` is not cosmetic; it changes stride/contiguity and can make or break coalescing and vectorized register loads.
- Use **composition** to mirror hardware conventions (e.g., tensor core operand fragments). Read the PTX figure, then build it with `spatial`, `local`, and `repeat` in the same logical order.
- Use **reduce** to introduce **replication** when several threads must share an element (e.g., block-wide broadcast without shared-memory staging).
- Remember that **transformations** (`reshape`, `permute`, `flatten`, `squeeze/unsqueeze`) **keep** the count of threads and registers per thread; they only retile how those resources are mapped onto the tensor’s shape.


Appendix: Minimal ASCII “tile replacement” illustration
-------------------------------------------------------
Take an outer 2×2 layout. Replace every cell with an inner 2×3 tile.

Outer:
┌───┬───┐
│ A │ B │
├───┼───┤
│ C │ D │
└───┴───┘

Inner 2×3 (for each letter):
┌───┬───┬───┐
│ a │ b │ c │
├───┼───┼───┤
│ d │ e │ f │
└───┴───┴───┘

Composition (outer ∘ inner) gives a 4×6:
┌───┬───┬───┬───┬───┬───┐
│Aa │Ab │Ac │Ba │Bb │Bc │
├───┼───┼───┼───┼───┼───┤
│Ad │Ae │Af │Bd │Be │Bf │
├───┼───┼───┼───┼───┼───┤
│Ca │Cb │Cc │Da │Db │Dc │
├───┼───┼───┼───┼───┼───┤
│Cd │Ce │Cf │Dd │De │Df │
└───┴───┴───┴───┴───┴───┘

If the outer is `local(...)` and inner is `spatial(...)`, letters map to local slots first, then to threads. Reverse the order and you reverse the ownership pattern.


In [6]:
from tilus import float16, float32, int32
from tilus.utils import cdiv


class MatmulV0(tilus.Script):
    def __init__(self):
        super().__init__()
        # we define three hyperparameters: ``block_m``, ``block_n``, and ``block_k`` to determine the tile size on
        # m, n, and k dimensions for each `thread block` of the kernel.
        self.block_m = 64
        self.block_n = 64
        self.block_k = 16

    def __call__(
        self,
        m_size: int32,  # the size of the m dimension of the input matrix A and output matrix C
        n_size: int,  # the size of the n dimension of the input matrix B and output matrix C
        k_size: int,  # the size of the k dimension of the input matrix A and B
        a_ptr: ~float16,  # the pointer to the input matrix A, which is a 2D tensor of shape [m_size, k_size]
        b_ptr: ~float16,  # the pointer to the input matrix B, which is a 2D tensor of shape [k_size, n_size]
        c_ptr: ~float16,  # the pointer to the output matrix C, which is a 2D tensor of shape [m_size, n_size]
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),  # the x dimension size of the grid
            cdiv(n_size, self.block_n),  # the y dimension size of the grid
        ]
        self.attrs.warps = (
            1  # the number of warps per thread block, must be a compile-time known integer
        )

        # define two int32 variables to store the offsets of the m and n dimensions for the current thread block.
        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        # create two global tensors `ga` and `gb` to represent the input matrices A and B, respectively.
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # create a register tensor `acc` to accumulate the results of the matrix multiplication.
        acc = self.register_tensor(dtype=float32, shape=[self.block_m, self.block_n], init=0.0)

        # iterate over the k dimension in blocks of size `block_k`.
        for k in range(cdiv(k_size, self.block_k)):
            # calculate the offset for the current block in the k dimension
            offset_k = k * self.block_k

            # load a block of matrix A and B into register tensors `a` and `b`.
            a = self.load_global(
                ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k]
            )
            b = self.load_global(
                gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n]
            )

            # perform the dot product: acc = a @ b + acc
            self.dot(a, b, acc, out=acc)

        # after the loop, we cast the accumulated result `acc` to float16 type and store it back to the output matrix C.
        acc_f16 = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, acc_f16, offsets=[offset_m, offset_n])


import pandas
import torch
from tilus.utils import benchmark_func


def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [[4096, 4096, 4096]]

    rows = []
    for m, n, k in workloads:
        # create an instance of the kernel we have just defined
        matmul = MatmulV0()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        torch.cuda.synchronize()

        # launch the kernel by passing required arguments
        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)
if __name__ == "__main__":
    main()

[Building] matmul_v0-4096-4096-d1: 100%|██████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.05s/it][Building] matmul_v0-4096-4096-d1: 100%|██████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.05s/it]


      m     n     k   name  latency (ms)     tflops
0  4096  4096  4096  torch       1.84832  74.358852
1  4096  4096  4096  tilus       4.42880  31.033000
