In [1]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver

In [2]:
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")

In [3]:
properties = driver.active.utils.get_device_properties(DEVICE.index)
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def transpose_kernel(input_ptr: torch.Tensor, output_ptr: torch.Tensor, 
              num_rows: int, num_cols: int, 
              BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)
    # we'll load (BLOCK_SIZE, BLOCK_SIZE) in shared memory and store (BLOCK_SIZE, BLOCK_SIZE) in output.
    row_start = row_block * BLOCK_SIZE
    col_start = col_block * BLOCK_SIZE
    row_offs = tl.arange(0, BLOCK_SIZE)
    col_offs = tl.arange(0, BLOCK_SIZE)
    rows = row_start + row_offs
    cols = col_start + col_offs
    input_ptrs = input_ptr + rows[None, :] * num_cols + cols[None, :]
    mask = (rows[:, None] < num_rows) & (cols[None, :] < num_cols)
    block = tl.load(input_ptrs) # (BLOCK_SIZE, BLOCK_SIZE)
    
    out_row_start = col_start
    out_col_start = row_start
    out_rows = out_row_start + col_offs  # Note the swap of row_offs and col_offs
    out_cols = out_col_start + row_offs  # Note the swap of row_offs and col_offs
    out_ptrs = output_ptr + out_rows[:, None] * num_rows + out_cols[None, :] # Note num_rows here
    out_mask = (out_rows[:, None] < num_cols) & (out_cols[None, :] < num_rows) # Note the swap of num_rows and num_cols

    tl.store(out_ptrs, block, mask=out_mask)
    

In [None]:
@triton.jit
def transpose_kernel(input_ptr: torch.Tensor, output_ptr: torch.Tensor,
              num_rows: int, num_cols: int,
              BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, num_stages: tl.constexpr):

    row_block = tl.program_id(0)
    col_block = tl.program_id(1)

    row_offs = tl.arange(0, BLOCK_SIZE_M)
    col_offs = tl.arange(0, BLOCK_SIZE_N)

    # Input indices
    rows = row_offs + row_block * BLOCK_SIZE_M
    cols = col_offs + col_block * BLOCK_SIZE_N
    # we load in transposed manner.
    input_ptrs = input_ptr + rows[None, :] * num_cols + cols[:, None]
    mask = (rows[None, :] < num_rows) & (cols[:, None] < num_cols)
    block = tl.load(input_ptrs, mask=mask)

    # Output indices - note the reversal of row and col block and offsets
    out_rows = col_block * BLOCK_SIZE_N + col_offs
    out_cols = row_block * BLOCK_SIZE_M + row_offs
    out_ptrs = output_ptr + out_rows[:, None] * num_cols + out_cols[None, :] # Output has num_cols rows
    out_mask = (out_rows[:, None] < num_cols) & (out_cols[None, :] < num_rows) # Output has num_rows cols

    tl.store(out_ptrs, block, mask=out_mask)

In [57]:
torch.manual_seed(0)
BLOCK_SIZE_M = 16  # or any appropriate block size
BLOCK_SIZE_N = 8
num_rows = 64
num_cols = 64
input_ptr = torch.randn((num_rows, num_cols), dtype=torch.float32, device=DEVICE)
output_ptr = torch.zeros((num_cols, num_rows), dtype=torch.float32, device=input_ptr.device) # Output dimensions are swapped
num_stages = 1
grid = (triton.cdiv(num_rows, BLOCK_SIZE_M), triton.cdiv(num_cols, BLOCK_SIZE_N)) # Grid based on input blocking
transpose_kernel[grid](input_ptr, output_ptr, num_rows, num_cols, 
                       BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
                       num_stages=num_stages)

# Verify the transpose
expected_output = input_ptr.T.contiguous()
error = torch.mean(torch.abs(expected_output - output_ptr))
print(f"Error: {error}")

# You can also visually inspect a small portion
if num_rows <= 8 and num_cols <= 8:
    print("Input:\n", input_ptr.cpu().numpy())
    print("Output:\n", output_ptr.cpu().numpy())
    print("Expected Output:\n", expected_output.cpu().numpy())

Error: 0.0


In [None]:
def transpose_kernel_np(input: np.ndarray, output: np.ndarray, num_rows: int, num_cols: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, row_block: int, col_block: int):
    input_row_offsets = np.arange(0, BLOCK_SIZE_M)[:, None]
    input_col_offsets = np.arange(0, BLOCK_SIZE_N)[None, :]
    
    row_ptrs = np.arange(0, BLOCK_SIZE_M)[None, :] + row_block * BLOCK_SIZE_M
    col_ptrs = np.arange(0, BLOCK_SIZE_N)[:, None] + col_block * BLOCK_SIZE_N

    block = y[row_ptrs, col_ptrs]
    out_row_ptrs = np.arange(0, BLOCK_SIZE_N)[:, None] + col_block * BLOCK_SIZE_N
    out_col_ptrs = np.arange(0, BLOCK_SIZE_M)[None, :] + row_block * BLOCK_SIZE_M
    output[out_row_ptrs, out_col_ptrs] = block

In [None]:
import numpy as np
BLOCK_SIZE_M, BLOCK_SIZE_N = 4, 2
y = np.arange(64).reshape(8, 8)
print(y)
num_row_blocks = num_rows // BLOCK_SIZE_M
num_col_blocks = num_cols // BLOCK_SIZE_N
output = np.empty((num_cols, num_rows), dtype=y.dtype)
for i in range(0, num_row_blocks):
    for j in range(0, num_col_blocks):
        transpose_kernel_np(y, output, num_rows, num_cols, BLOCK_SIZE_M, BLOCK_SIZE_N, i, j)

print(output)
np.alltrue(y.T == output)
# row_ptrs = input_row_offsets + row_block * BLOCK_SIZE_M
# col_ptrs = input_col_offsets + col_block * BLOCK_SIZE_N
# b = block.ravel()
# z = out.ravel()
# out = np.empty((c, r), dtype=y.dtype)
# ri = np.arange(0, BLOCK_SIZE_N)[:, None] * out.shape[1]
# ci = np.arange(0, BLOCK_SIZE_M)[None, :]
# oi = (ci + ri).ravel()
# z = out.ravel()
# z[oi] = b[ii]

[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47]
 [48 49 50 51 52 53 54 55]
 [56 57 58 59 60 61 62 63]]
[[ 0  8 16 24 32 40 48 56]
 [ 1  9 17 25 33 41 49 57]
 [ 2 10 18 26 34 42 50 58]
 [ 3 11 19 27 35 43 51 59]
 [ 4 12 20 28 36 44 52 60]
 [ 5 13 21 29 37 45 53 61]
 [ 6 14 22 30 38 46 54 62]
 [ 7 15 23 31 39 47 55 63]]


True

In [23]:
torch.manual_seed(0)
BLOCK_SIZE = 64  # or any appropriate block size
input_ptr = torch.randn((BLOCK_SIZE, BLOCK_SIZE), dtype=torch.float32, device=DEVICE)
output_ptr = torch.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=torch.float32, device=input_ptr.device)
num_rows, num_cols = input_ptr.shape
num_stages = 1
grid = (triton.cdiv(num_rows, BLOCK_SIZE), triton.cdiv(num_cols, BLOCK_SIZE))
transpose_kernel[grid](input_ptr, output_ptr, num_rows, num_cols, BLOCK_SIZE=BLOCK_SIZE,
                       num_stages=num_stages)
torch.mean(torch.abs(input_ptr.T.contiguous() - output_ptr))

tensor(1.1017, device='cuda:0')

In [10]:
@triton.jit
def reshape_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
    x_offsets = (
        tl.arange(0, BLOCK_SIZE)[None, :] * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)[:, None]
    )
    y_offsets = (
        tl.arange(0, BLOCK_SIZE)[None, :]
        + tl.arange(0, BLOCK_SIZE)[:, None] * BLOCK_SIZE
    )
    x = tl.load(x_ptr + x_offsets)
    tl.store(y_ptr + y_offsets, x)


BLOCK_SIZE = 64
x = torch.randn((BLOCK_SIZE * BLOCK_SIZE), device="cuda", dtype=torch.float32).view(
    (BLOCK_SIZE, BLOCK_SIZE)
)
y = torch.zeros((BLOCK_SIZE * BLOCK_SIZE), device="cuda", dtype=torch.float32).view(
    (BLOCK_SIZE, BLOCK_SIZE)
)
reshape_kernel[(1,)](x, y, BLOCK_SIZE)

<triton.compiler.compiler.CompiledKernel at 0x7fe2364c7fa0>

In [74]:
torch.mean(torch.abs(x - y.T))

tensor(0., device='cuda:0')

<triton.compiler.compiler.CompiledKernel at 0x7fa060bbbb80>

In [76]:
torch.mean(torch.abs(input_ptr - output_ptr.T))

tensor(1.1017, device='cuda:0')

In [43]:
tl.trans?

[0;31mSignature:[0m [0mtl[0m[0;34m.[0m[0mtrans[0m[0;34m([0m[0minput[0m[0;34m:[0m [0;34m'tensor'[0m[0;34m,[0m [0;34m*[0m[0mdims[0m[0;34m,[0m [0m_builder[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Permutes the dimensions of a tensor.

If no permutation is specified, tries to do a (1,0) permutation, i.e. tries
to transpose a 2D tensor.

:param input: The input tensor.
:param dims: The desired ordering of dimensions.  For example,
    :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor.

:code:`dims` can be passed as a tuple or as individual parameters: ::

    # These are equivalent
    trans(x, (2, 1, 0))
    trans(x, 2, 1, 0)

:py:func:`permute` is equivalent to this function, except it doesn't
have the special case when no permutation is specified.

This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.trans(...)` instead of
:code:`trans(x, ...)`.
[0;31mFile:[0m      ~/m