In [1]:
import torch
import triton
import triton.language as tl

# Print tensors in full: wide lines and no "..." summarisation for the small demo matrices.
torch.set_printoptions(
    threshold=100_000,
    linewidth=200,
)

# Understand Block pointers in Triton

In this notebook I want to understand exactly how loading a matrix and transposing it in the same step works.
I have used this in several implementations but still find it unintuitive.

What does the `order` argument do?
In this case it does not make any difference to the result. Why?

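Okay, so `block_shape` and `order` apparently do not make a difference for the outputs here but help the compiler! See https://github.com/triton-lang/triton/issues/3890#issuecomment-2143708987 (Note: the two axes of `block_shape` can only be swapped without effect because `CHUNK_SIZE == DHQK` in this notebook, so both candidate shapes are the same square tile.)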

This explains why both `order` settings tried in the kernel below yield the same result.

In [2]:
S = 16  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DHQK = 4  # dimension per head

# The chunk length is tied to the head dimension: the output buffer reuses the
# input's (S, DHQK) layout, so a transposed chunk only fits back into its own
# slot when the chunk is square.
CHUNK_SIZE = DHQK
# Guard against silently dropping a partial trailing chunk in the floor division.
assert S % CHUNK_SIZE == 0, "S must be a multiple of CHUNK_SIZE"
NC = S // CHUNK_SIZE  # number of chunks along the sequence

DTYPE = torch.float32
DEVICE = torch.device("cuda:0")

In [3]:
# Fill the input with 0, 1, 2, ... so the origin of every output element can be
# read off directly from its value.
matIn = (
    torch.arange(B * NH * S * DHQK)
    .reshape(B, NH, NC * CHUNK_SIZE, DHQK)
    .to(dtype=DTYPE, device=DEVICE)
)
# The output buffer mirrors the input's shape, dtype and device (an explicit
# (B, NH, NC, DHQK, CHUNK_SIZE) buffer would be the alternative layout).
matOut = torch.zeros_like(matIn)

In [6]:
@triton.jit
def matrix_load_transpose_kernel(
    matIn,
    matOut,
    str_BNH,
    str_S,
    str_DHQK,
    B: tl.constexpr,
    NH: tl.constexpr,
    S: tl.constexpr,
    DHQK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
):
    """Load one chunk of ``matIn`` already transposed and write it to ``matOut``.

    Program axis 0 selects the chunk along the sequence, program axis 1 selects
    the flattened batch*head slice. The transpose happens purely through the
    block pointer: the same memory is described as a (DHQK, S) matrix by
    swapping the shape and strides, so ``tl.load`` directly yields the
    transposed tile.

    Args:
        matIn: pointer to the input, addressed as (batch*head, S, DHQK).
        matOut: pointer to the output, addressed with the same strides.
        str_BNH: stride between consecutive batch*head slices.
        str_S: stride along the sequence dimension.
        str_DHQK: stride along the head dimension.
        B, NH: not used inside the kernel; kept for the call signature.
        S: sequence length.
        DHQK: head dimension.
        CHUNK_SIZE: number of sequence positions per chunk; must equal DHQK.
    """
    # The transposed tile is (DHQK, CHUNK_SIZE) but it is written into a
    # (CHUNK_SIZE, DHQK) tile of matOut, which only matches for square chunks.
    tl.static_assert(
        CHUNK_SIZE == DHQK,
        "matOut reuses the (S, DHQK) layout, so CHUNK_SIZE must equal DHQK",
    )

    idx_b_NC, idx_b_BNH = tl.program_id(0), tl.program_id(1)

    # For comparison, the plain (non-transposed) load of the same chunk would be:
    #   tl.make_block_ptr(
    #       base=matIn + idx_b_BNH * str_BNH,
    #       shape=(S, DHQK),
    #       strides=(str_S, str_DHQK),
    #       offsets=(idx_b_NC * CHUNK_SIZE, 0),
    #       block_shape=(CHUNK_SIZE, DHQK),
    #       order=(1, 0),
    #   )

    # Transposed view: shape, strides, offsets and block_shape are all swapped
    # consistently, so the tile is (DHQK, CHUNK_SIZE).
    matIn_ptr = tl.make_block_ptr(
        base=matIn + idx_b_BNH * str_BNH,
        shape=(DHQK, S),
        strides=(str_DHQK, str_S),
        offsets=(0, idx_b_NC * CHUNK_SIZE),
        block_shape=(DHQK, CHUNK_SIZE),
        # Layout hint only; the experiments in this notebook show the loaded
        # values do not depend on it. Dim 0 of this view steps by str_DHQK
        # (presumably the unit stride for a contiguous input - confirm at the
        # call site), so it is listed first as the fastest-varying dim.
        order=(0, 1),
    )

    matIn_val = tl.load(matIn_ptr)

    # The output is addressed in the regular (S, DHQK) layout; the chunk's slot
    # is the (CHUNK_SIZE, DHQK) tile starting at row idx_b_NC * CHUNK_SIZE.
    matOut_ptr = tl.make_block_ptr(
        base=matOut + idx_b_BNH * str_BNH,
        shape=(S, DHQK),
        strides=(str_S, str_DHQK),
        offsets=(idx_b_NC * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, DHQK),
        # Regular view: the last dim steps by str_DHQK, so it is listed first.
        order=(1, 0),
    )
    tl.store(matOut_ptr, matIn_val)

In [7]:
# One program per (chunk, batch*head) pair. The kernel reads program axis 0 as
# the chunk index, so that axis must span the number of chunks NC, not the
# chunk length CHUNK_SIZE (the two merely coincide for S=16, CHUNK_SIZE=4).
grid = (NC, B * NH)
matrix_load_transpose_kernel[grid](
    matIn,
    matOut,
    # matOut was created with zeros_like(matIn), so matIn's strides address both.
    matIn.stride(1),
    matIn.stride(2),
    matIn.stride(3),
    B,
    NH,
    S,
    DHQK,
    CHUNK_SIZE,
)

<triton.compiler.compiler.CompiledKernel at 0x744dfb191b50>

In [8]:
matIn

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.],
          [36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.],
          [48., 49., 50., 51.],
          [52., 53., 54., 55.],
          [56., 57., 58., 59.],
          [60., 61., 62., 63.]]]], device='cuda:0')

In [9]:
matOut

tensor([[[[ 0.,  4.,  8., 12.],
          [ 1.,  5.,  9., 13.],
          [ 2.,  6., 10., 14.],
          [ 3.,  7., 11., 15.],
          [16., 20., 24., 28.],
          [17., 21., 25., 29.],
          [18., 22., 26., 30.],
          [19., 23., 27., 31.],
          [32., 36., 40., 44.],
          [33., 37., 41., 45.],
          [34., 38., 42., 46.],
          [35., 39., 43., 47.],
          [48., 52., 56., 60.],
          [49., 53., 57., 61.],
          [50., 54., 58., 62.],
          [51., 55., 59., 63.]]]], device='cuda:0')

In [8]:
# Verify against a pure-PyTorch reference instead of relying on inspection:
# transpose every (CHUNK_SIZE, DHQK) chunk and lay the chunks out like matOut.
matOut_ref = matIn.reshape(B, NH, NC, CHUNK_SIZE, DHQK).transpose(-1, -2).reshape(matOut.shape)
assert torch.equal(matOut, matOut_ref), "kernel output differs from the per-chunk transpose"
# Zeros on each chunk's diagonal and antisymmetric off-diagonals are the
# signature of a per-chunk transpose.
matIn - matOut

tensor([[[[ 0., -3., -6., -9.],
          [ 3.,  0., -3., -6.],
          [ 6.,  3.,  0., -3.],
          [ 9.,  6.,  3.,  0.],
          [ 0., -3., -6., -9.],
          [ 3.,  0., -3., -6.],
          [ 6.,  3.,  0., -3.],
          [ 9.,  6.,  3.,  0.],
          [ 0., -3., -6., -9.],
          [ 3.,  0., -3., -6.],
          [ 6.,  3.,  0., -3.],
          [ 9.,  6.,  3.,  0.],
          [ 0., -3., -6., -9.],
          [ 3.,  0., -3., -6.],
          [ 6.,  3.,  0., -3.],
          [ 9.,  6.,  3.,  0.]]]], device='cuda:0')

In [9]:
# Reference copy of the matOut printed above (every 4x4 chunk of matIn transposed):
# tensor([[[[ 0.,  4.,  8., 12.],
#           [ 1.,  5.,  9., 13.],
#           [ 2.,  6., 10., 14.],
#           [ 3.,  7., 11., 15.],
#           [16., 20., 24., 28.],
#           [17., 21., 25., 29.],
#           [18., 22., 26., 30.],
#           [19., 23., 27., 31.],
#           [32., 36., 40., 44.],
#           [33., 37., 41., 45.],
#           [34., 38., 42., 46.],
#           [35., 39., 43., 47.],
#           [48., 52., 56., 60.],
#           [49., 53., 57., 61.],
#           [50., 54., 58., 62.],
#           [51., 55., 59., 63.]]]], device='cuda:0')