# Write your own kernel with the Neuron Kernel Interface (NKI)
In this notebook you'll learn how to develop your own kernel with [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/index.html). A kernel is a set of user-defined functions that are executed largely as defined by the user, not by the compiler. With NKI you can write your own functions to define any operations you like, using supported APIs, and execute them on Trainium and Inferentia hardware. You have the control and lower-level access to define the data movement, computational patterns, and physical execution for the mathematics of your algorithms with NKI.

The structure of the notebook is as follows:
1. Brief introduction to the NeuronCore and the NKI programming model
2. Your first NKI kernel - tensor addition
3. Your second NKI kernel - matrix multiplication
4. Wrap up and next steps

### 1. Introduction to the NeuronCore and NKI programming model
The NeuronCore is the main acceleration unit within the AWS AI chips Trainium and Inferentia. As you can see in the image below, it is composed of 4 compute engines. These engines are based on a systolic array architecture. The compute engines are fed data from the primary on-chip memory cache, SBUF. Data is moved from the HBM banks to SBUF when you call `nl.load`. You'll index into your tensors to create lower-level objects called `tiles`. A tile is the result of `nl.load`. Once you've defined `tiles`, you can send them to various NKI mathematical APIs such as `add`, `subtract`, `matmul`, etc. The results of these operations are stored in the secondary on-chip memory cache, PSUM. After moving the data back to SBUF, you can then send it back to HBM with `nl.store`.

<img src=https://awsdocs-neuron.readthedocs-hosted.com/en/latest/_images/pm-nc.png width="400"/>

Trainium1 chips feature two NeuronCore-v2 acceleration units, 2 HBM banks, NeuronLink-v2 chip-to-chip connect, host PCIe, and dedicated engines for both data movement and collective communications. Trainium1 offers 32 GiB of device memory (sum of all 4 HBM banks), with 840 GiB/sec of bandwidth. Trainium1 instances feature 16 Trainium chips, providing a total of up to 3 petaflops of FP16 compute and 512 GiB of accelerator memory capacity. For more architectural details, see our docs [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trainium.html#trainium-arch).
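As a quick sanity check on the instance-level memory figure, the arithmetic can be reproduced from the per-chip numbers quoted above. This is only a sketch of that calculation; the variable names are illustrative and not part of any Neuron API.

```python
# Figures quoted in the paragraph above.
chips_per_instance = 16
device_memory_per_chip_gib = 32

# Total accelerator memory across all chips on one instance.
total_memory_gib = chips_per_instance * device_memory_per_chip_gib
print(f"Total accelerator memory: {total_memory_gib} GiB")
```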


The on-chip memory cache, SBUF, **has ~20x higher memory bandwidth than HBM**. The purpose of your kernel is to exploit as much of that compute acceleration as you can within the context of your model and workload.

#### Structuring data and tensors for NKI

To easily move data and design our kernels on NKI, we'll want to exploit the 128 partitions built into SBUF as shown in the image below. In particular, SBUF has 128 partition lanes. Each of these lanes can execute programs in parallel on the engines. As much as possible, we'll want to align the tensors and data structures in our algorithms to follow this physical design. The benefit is that our kernels will run faster and be easier to develop!

Your data movement from HBM to SBUF should be very carefully aligned with this 128-lane partition dimension, also called p-dim. Each tile needs a precise definition along the p-dim. Your second dimension is called the free dimension, or f-dim. As the name suggests, this dimension is much more flexible than p-dim. Though it may surprise you, it's better not to fully saturate SBUF with extremely large tiles. This is so that the compiler can overlap data movement and collectives with compute, giving you better overall compute utilization and performance.

<img src=https://awsdocs-neuron.readthedocs-hosted.com/en/latest/_images/pm-layout.png width="600"/>
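To make the partition dimension concrete, here is a minimal NumPy-only sketch of splitting a tensor's rows into chunks of at most 128, one chunk per tile. It does not touch the hardware or the NKI API. The helper name `partition_chunks` and the 300-row example shape are assumptions chosen for illustration.

```python
import numpy as np

P_MAX = 128  # number of SBUF partition lanes, as described above


def partition_chunks(num_rows, p_max=P_MAX):
    """Yield (start, stop) row ranges of at most p_max rows each."""
    for start in range(0, num_rows, p_max):
        yield start, min(start + p_max, num_rows)


# A hypothetical tensor whose row count is not a multiple of 128.
x = np.arange(300 * 4, dtype=np.float32).reshape(300, 4)

chunks = list(partition_chunks(x.shape[0]))
tiles = [x[start:stop, :] for start, stop in chunks]

print(chunks)
print([t.shape for t in tiles])
```

The last chunk is smaller than 128 rows, which is the kind of edge case a real kernel would also need to handle when the partition dimension does not divide evenly.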

### 2. Your first NKI kernel
Now that you have some understanding of the compute architecture and the motivation for kernels, let's write your first NKI kernel! Importing the `nki` library may take a few moments the first time you import it on an instance.

In [1]:
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl

In [2]:
@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
 
  # Create output tensor 
  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  a_tile = nl.load(a_input)
  b_tile = nl.load(b_input)

  # compute a + b
  c_tile = a_tile + b_tile

  # Store the result from on-chip memory (SBUF) back to device memory (HBM)
  nl.store(c_output, value=c_tile)

  # Transfer the ownership of `c_output` to the caller
  return c_output


In [3]:
a = np.random.rand(128, 512).astype(np.float16)
b = np.random.rand(128, 512).astype(np.float16)

output_nki = nki_tensor_add_kernel_(a, b)

output_np = a + b

allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
if allclose:
    print("NKI and NumPy match")
else:
    print("NKI and NumPy differ")


NKI and NumPy match
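The comparison above relies on `np.allclose`, which checks `|x - y| <= atol + rtol * |y|` element-wise. The sketch below spells that check out by hand on float16 data. It runs in NumPy only and uses a float32 sum as a stand-in reference, so it illustrates the tolerance logic and not the kernel itself.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((128, 512)).astype(np.float16)
b = rng.random((128, 512)).astype(np.float16)

# Higher-precision reference versus the float16 sum (rounded to float16).
reference = a.astype(np.float32) + b.astype(np.float32)
rounded = (a + b).astype(np.float32)

atol, rtol = 1e-4, 1e-2

# Manual version of the element-wise tolerance check used by np.allclose.
manual = bool(np.all(np.abs(reference - rounded) <= atol + rtol * np.abs(rounded)))
builtin = bool(np.allclose(reference, rounded, atol=atol, rtol=rtol))

print(manual, builtin)
```

A relative tolerance of `1e-2` is generous compared with float16 rounding error, which is why a low-precision result can still pass the check.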


Now let's see if we can do that for matrix multiplication!

### 3. Your second NKI kernel
Now, let's use PyTorch tensors and pass them to the device with XLA. Then we'll try a matrix multiplication kernel.

If you get any errors, you may need to use a different Python environment. Choose the Python 3.9 kernel (or create a new venv) and install these packages:

%pip install neuronx-cc==2.18.121.0+9e31e41a

%pip install torch torch_xla torch_neuronx
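Before running the next cells, it can help to confirm which of these packages are visible in the current environment. The following is a small sketch that uses only the standard library. The helper name `find_missing` is hypothetical, and the module list assumes the import names used elsewhere in this notebook (`neuronxcc`, `torch`, `torch_xla`) plus `torch_neuronx`.

```python
import importlib.util

# Top-level import names for the packages listed above (assumed).
REQUIRED = ["neuronxcc", "torch", "torch_xla", "torch_neuronx"]


def find_missing(modules):
    """Return the module names that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED)
print("Missing packages:", missing if missing else "none")
```

This only checks that the modules can be located. It does not verify versions or that a Neuron device is attached.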



In [4]:
import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()

lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)



In [5]:
@nki.jit
def nki_matmul_basic_(lhsT, rhs):
  """NKI kernel to compute a 64x128x512 matrix multiplication operation

  Args:
      lhsT: an input tensor of shape [128,64], a left hand side argument of the
        matrix multiplication, delivered transposed for optimal performance
      rhs: an input tensor of shape [128,512], a right hand side argument of the
        matrix multiplication
  Returns:
      result: the resulting output tensor of shape [64,512]
  """
  result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Defining indexes for input LHS.T
  # - Note: here we take LayoutConstraint #1 into account:
  # "For MatMult, contraction axis must be mapped to P-dim"
  i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:64]

  # Defining indexes for input RHS
  # - Note: here we take LayoutConstraint #1 into account:
  # "For MatMult, contraction axis must be mapped to P-dim"
  i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]

  # Defining indexes for the output ([64,128]@[128,512] -> [64,512])
  i_out_p, i_out_f = nl.mgrid[0:64, 0:512]

  # Loading the inputs (HBM->SBUF)
  # Note: here we take Tile dtype definition into account,
  # which forces P-dim as the left most index
  lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
  rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

  # Perform the matrix-multiplication
  # Note1: We set transpose_x to True, to indicate that the LHS input is transposed
  # Note2: A NKI matmul instruction always writes to PSUM in float32 data-type
  result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)

  # Copy the result from PSUM back to SBUF, and cast to expected output data-type
  result_sbuf = nl.copy(result_psum, dtype=result.dtype)

  # The result of a [64,128] x [128,512] matrix multiplication has a shape of [64, 512].
  # This dictates which indices to use to address the result tile.
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)

  return result

In [6]:
# Run NKI kernel
output_small = nki_matmul_basic_(lhs_small.T, rhs_small)

# Run torch reference
output_small_torch = torch.matmul(lhs_small, rhs_small)

# Compare results
print("Checking correctness of nki_matmul_basic")
if torch.allclose(output_small_torch, output_small, atol=1e-4, rtol=1e-2):
  print("NKI and Torch match")
else:
  print("NKI and Torch differ")

Checking correctness of nki_matmul_basic
2025-03-17 22:45:04.000657:  512118  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/ec2-user/neuroncc_compile_workdir/58a5f9b5-7dd1-4569-b58f-bae92b1f0d13/model.MODULE_6255296715421101974+e30acd3a.hlo_module.pb --output /tmp/ec2-user/neuroncc_compile_workdir/58a5f9b5-7dd1-4569-b58f-bae92b1f0d13/model.MODULE_6255296715421101974+e30acd3a.neff --target=trn1 --verbose=35
.
Compiler status PASS
NKI and Torch match
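The layout used by the kernel can be mirrored in plain NumPy: both operands carry the contraction axis (length 128) as their leading dimension, the product is accumulated in float32, and the result is cast back to the input dtype. This is a sketch under assumptions and not the NKI implementation. It uses float16 as a stand-in because NumPy has no bfloat16, and `np.einsum` here only emulates what `nl.matmul(..., transpose_x=True)` computes.

```python
import numpy as np

rng = np.random.default_rng(0)
lhs = rng.random((64, 128)).astype(np.float16)   # [M, K]
rhs = rng.random((128, 512)).astype(np.float16)  # [K, N]

# Deliver the left-hand side transposed, so K is the leading axis of both inputs.
lhsT = lhs.T  # [K, M]
assert lhsT.shape[0] == rhs.shape[0] == 128

# Contract over the leading axis, accumulating in float32 (emulating PSUM).
acc = np.einsum("km,kn->mn", lhsT.astype(np.float32), rhs.astype(np.float32))

# Cast back to the input dtype, as the kernel does when copying PSUM to SBUF.
result = acc.astype(lhs.dtype)

reference = lhs.astype(np.float32) @ rhs.astype(np.float32)
matches = bool(np.allclose(reference, result, atol=1e-4, rtol=1e-2))
print(result.shape, matches)
```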


### 4. Wrap up and next steps
The simplicity you see in the `tensor_add` kernel above is possible because the shapes we pass in are very small. We've intentionally selected them to exactly match the maximum tile dimensions that NKI supports, for both the partition and free dimensions.

As you saw above, the partition dimension has a maximum length of 128. This is the most important dimension and shape to embrace in your kernels, because it impacts your ability to load data onto the chip. In order to exploit the parallelism of execution enabled through the 128 lanes on SBUF, you might want to build into your kernel the ability to extract data in batches of 128 to load onto SBUF.

The second dimension, also known as the free dimension, is more flexible. Once you have clean batches of 128 lanes being loaded onto SBUF, you can build in tiling on the second dimension with more varied sizes, up to 512.
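To illustrate what such tiling could look like for larger inputs, here is a NumPy-only sketch of an element-wise add that walks the tensor in tiles of at most 128 rows by 512 columns. The function name `tiled_add` and the example shape are assumptions. In a real kernel the slicing, addition, and assignment would be replaced by `nl.load`, the tile add, and `nl.store`.

```python
import numpy as np

P_MAX, F_MAX = 128, 512  # tile limits quoted in this section


def tiled_add(a, b, p_max=P_MAX, f_max=F_MAX):
    """Add two equally shaped 2-D arrays one tile at a time."""
    assert a.shape == b.shape and a.ndim == 2
    out = np.empty_like(a)
    rows, cols = a.shape
    for p0 in range(0, rows, p_max):
        for f0 in range(0, cols, f_max):
            p1 = min(p0 + p_max, rows)
            f1 = min(f0 + f_max, cols)
            # Stand-ins for load (slice), compute (add), and store (assign).
            a_tile = a[p0:p1, f0:f1]
            b_tile = b[p0:p1, f0:f1]
            out[p0:p1, f0:f1] = a_tile + b_tile
    return out


rng = np.random.default_rng(0)
a = rng.random((300, 1100)).astype(np.float16)
b = rng.random((300, 1100)).astype(np.float16)

same = bool(np.array_equal(tiled_add(a, b), a + b))
print(same)
```

The example shape is deliberately not a multiple of either tile limit, so the loop also exercises the partial tiles at the edges.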

To learn more about tiling, and to step through the rest of the matrix multiplication tutorial, see our docs on the topic [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/matrix_multiplication.html#).