In [1]:
# Simulate Triton's matrix multiplication tile mapping
# Output matrix size: M rows by N columns (indices wrap modulo these below).
# K is presumably the reduction dimension; it is not read by the tile
# mapping in this cell.
M, N, K = 8, 8, 8
# Tile sizes per dimension; BLOCK_K (like K) is unused by the mapping here.
BLOCK_M, BLOCK_N, BLOCK_K = 4, 4, 4
# Number of tile-rows that share one group in the reordered launch order.
GROUP_M = 8

import math

def ceil_div(a, b):
    """Return ``a / b`` rounded up to the next whole number.

    Negating, floor-dividing and negating again yields the ceiling without
    a round-trip through float, so it stays exact for arbitrarily large ints.
    """
    floored_negative = (-a) // b
    return -floored_negative

# Number of output tiles along each dimension (a partial edge tile counts as one).
grid_m = ceil_div(M, BLOCK_M)
grid_n = ceil_div(N, BLOCK_N)

print("Without reordering:")
for pid in range(grid_m * grid_n):
    # Row-major launch order: consecutive program IDs walk along one row of
    # output tiles before moving on to the next row.
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    # Compute row indices (rm) and column indices (rn) for this tile
    rm = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]
    rn = [pid_n * BLOCK_N + i for i in range(BLOCK_N)]

    # Wrap out-of-range indices back into the matrix (modulo M / N) rather
    # than masking them, mirroring the `% M` / `% N` on the offsets in
    # Triton's matmul kernel.
    rm = [i % M for i in rm]
    rn = [i % N for i in rn]

    print(f"\nProgram ID (pid): {pid}")
    print(f"Tile indices without reordering -> pid_m: {pid_m}, pid_n: {pid_n}")
    print(f"Row indices (rm): {rm}")
    print(f"Column indices (rn): {rn}")

print("\nWith reordering:")
# Grouped launch order: program IDs are handed out column-major inside groups
# of GROUP_M tile-rows, so consecutive programs cover a compact block of
# output tiles (the intent in Triton's matmul is better L2 cache reuse).
width = GROUP_M * grid_n  # number of programs in one full group
for pid in range(grid_m * grid_n):
    # Which group of GROUP_M tile-rows this program belongs to
    group_id = pid // width
    first_pid_m = group_id * GROUP_M
    # The last group may hold fewer than GROUP_M tile-rows
    group_size = min(grid_m - first_pid_m, GROUP_M)
    # Both coordinates come from the program's offset *within its group*.
    # Using the global `pid % group_size` for pid_m scrambles the row order
    # of a partial last group whenever `width` is not a multiple of
    # `group_size` (e.g. GROUP_M=3, grid_m=5, grid_n=1 maps pid 3 -> row 4).
    local_pid = pid % width
    pid_m = first_pid_m + (local_pid % group_size)
    pid_n = local_pid // group_size

    # Compute row indices (rm) and column indices (rn) for this tile
    rm = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]
    rn = [pid_n * BLOCK_N + i for i in range(BLOCK_N)]

    # Wrap out-of-range indices back into the matrix (modulo M / N)
    rm = [i % M for i in rm]
    rn = [i % N for i in rn]

    print(f"\nProgram ID (pid): {pid}")
    print(f"Tile indices with reordering -> pid_m: {pid_m}, pid_n: {pid_n}")
    print(f"Row indices (rm): {rm}")
    print(f"Column indices (rn): {rn}")


Without reordering:

Program ID (pid): 0
Tile indices without reordering -> pid_m: 0, pid_n: 0
Row indices (rm): [0, 1, 2, 3]
Column indices (rn): [0, 1, 2, 3]

Program ID (pid): 1
Tile indices without reordering -> pid_m: 0, pid_n: 1
Row indices (rm): [0, 1, 2, 3]
Column indices (rn): [4, 5, 6, 7]

Program ID (pid): 2
Tile indices without reordering -> pid_m: 1, pid_n: 0
Row indices (rm): [4, 5, 6, 7]
Column indices (rn): [0, 1, 2, 3]

Program ID (pid): 3
Tile indices without reordering -> pid_m: 1, pid_n: 1
Row indices (rm): [4, 5, 6, 7]
Column indices (rn): [4, 5, 6, 7]

With reordering:

Program ID (pid): 0
Tile indices with reordering -> pid_m: 0, pid_n: 0
Row indices (rm): [0, 1, 2, 3]
Column indices (rn): [0, 1, 2, 3]

Program ID (pid): 1
Tile indices with reordering -> pid_m: 1, pid_n: 0
Row indices (rm): [4, 5, 6, 7]
Column indices (rn): [0, 1, 2, 3]

Program ID (pid): 2
Tile indices with reordering -> pid_m: 0, pid_n: 1
Row indices (rm): [0, 1, 2, 3]
Column indices (rn): [4, 