In [17]:
import torch

In [18]:
A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32)
B = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]], dtype=torch.float32)

**Example 1**

In [29]:
A, A.shape

(tensor([[1., 2., 3., 4.],
         [5., 6., 7., 8.]]),
 torch.Size([2, 4]))

In [30]:
B, B.shape

(tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],
         [ 7.,  8.,  9., 10., 11., 12.],
         [13., 14., 15., 16., 17., 18.],
         [19., 20., 21., 22., 23., 24.]]),
 torch.Size([4, 6]))

Write a matrix multiplication between `A` and `B`

In [31]:
output = torch.zeros(2, 6)
n_rows, n_cols = output.shape
shared_dim = 4

In [32]:
for row_idx in range(n_rows):
    for column_idx in range(n_cols):
        for k in range(shared_dim):
            output[row_idx, column_idx] += A[row_idx, k] * B[k, column_idx]

In [33]:
output == A @ B

tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True]])

**Example 2**

In [42]:
# Example usage with PyTorch tensors
# Width = 4
# A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=torch.float32)
# B = torch.tensor([[16, 15, 14, 13], [12, 11, 10, 9], [8, 7, 6, 5], [4, 3, 2, 1]], dtype=torch.float32)

A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32)
B = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]], dtype=torch.float32)

In [57]:
B

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],
        [ 7.,  8.,  9., 10., 11., 12.],
        [13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24.]])

In [43]:
A.shape, B.shape

(torch.Size([2, 4]), torch.Size([4, 6]))

In [44]:
import torch

In [52]:
output = torch.zeros((2, 6), dtype=torch.float32)
n_rows, n_cols = output.shape
shared_dim = 4

In [53]:
# Flatten the matrices to simulate 1D array indexing like in the CUDA code
A_flat = A.view(-1)
B_flat = B.view(-1)

In [55]:
# Loop over each element of the result matrix P
for row_idx in range(n_rows):  # Equivalent to thread row index
    for column_idx in range(n_cols):  # Equivalent to thread column index
        P_value = 0
        # Compute the dot product for the current element in P
        for k in range(shared_dim):  # Loop over the shared dimension
            P_value += A_flat[row_idx * n_rows + k] * B_flat[k * n_cols + column_idx]
        # Store the result in P
        output[row_idx, column_idx] = P_value

In [56]:
output

tensor([[130., 140., 150., 160., 170., 180.],
        [210., 228., 246., 264., 282., 300.]])

##### Example 3

In [3]:
A = torch.arange(N_ROWS*SHARED_DIM).reshape(N_ROWS, SHARED_DIM)
B = torch.arange(SHARED_DIM*N_COLUMNS).reshape(SHARED_DIM, N_COLUMNS)

In [4]:
import torch

In [5]:
N_ROWS, N_COLUMNS, SHARED_DIM = 4, 8, 16

In [6]:
A.shape, B.shape

(torch.Size([4, 16]), torch.Size([16, 8]))

In [7]:
TILE_SIZE = 2

Implement tiled matrix multiplication given each tile has size `(TILE_SIZE, TILE_SIZE) `

In [8]:
ROW_BLOCK_SIZE = N_ROWS // TILE_SIZE
COLUMN_BLOCK_SIZE = N_COLUMNS // TILE_SIZE
SHARED_DIM_BLOCK_SIZE = SHARED_DIM // TILE_SIZE

In [9]:
output = torch.zeros(N_ROWS, N_COLUMNS)

In [10]:
for row_start_idx in range(0, N_ROWS, ROW_BLOCK_SIZE):
    row_end_idx = row_start_idx + ROW_BLOCK_SIZE
    
    for column_start_idx in range(0, N_COLUMNS, COLUMN_BLOCK_SIZE):
        column_end_idx = column_start_idx + COLUMN_BLOCK_SIZE
        
        accum = torch.zeros(ROW_BLOCK_SIZE, COLUMN_BLOCK_SIZE)
        for k_start_idx in range(0, SHARED_DIM, SHARED_DIM_BLOCK_SIZE):
            k_end_idx = k_start_idx + SHARED_DIM_BLOCK_SIZE
            
            tiled_a = A[row_start_idx:row_end_idx, k_start_idx:k_end_idx]
            tiled_b = B[k_start_idx:k_end_idx, column_start_idx:column_end_idx]
            accum += tiled_a @ tiled_b

        output[row_start_idx:row_end_idx, column_start_idx:column_end_idx] = accum

In [11]:
output == A@B

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [13]:
2 % 2

0

In [15]:
torch.tensor([1, 2, 3, 4])[:, None]

tensor([[1],
        [2],
        [3],
        [4]])

In [16]:
torch.tensor([1, 2, 3, 4])[None, :]

tensor([[1, 2, 3, 4]])

##### Example 4

In [17]:
n_rows, n_cols = 4, 8
n_shared_dim = 16
block_size = 2

n_blocks_in_shared_dim = n_shared_dim // block_size

pids = torch.arange(n_rows*n_cols)

for pid in pids:
    row_start_idx = pid // n_blocks_in_shared_dim
    column_start_idx = pid % n_blocks_in_shared_dim

    print(row_start_idx, column_start_idx)

    # offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    # offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    # a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

In [19]:
n_blocks_in_shared_dim = n_shared_dim // block_size

In [23]:
pids = torch.arange(n_rows*n_cols)

In [26]:
for pid in pids:
    row_start_idx = pid // n_blocks_in_shared_dim
    column_start_idx = pid % n_blocks_in_shared_dim

    print(row_start_idx, column_start_idx)

tensor(0) tensor(0)
tensor(0) tensor(1)
tensor(0) tensor(2)
tensor(0) tensor(3)
tensor(0) tensor(4)
tensor(0) tensor(5)
tensor(0) tensor(6)
tensor(0) tensor(7)
tensor(1) tensor(0)
tensor(1) tensor(1)
tensor(1) tensor(2)
tensor(1) tensor(3)
tensor(1) tensor(4)
tensor(1) tensor(5)
tensor(1) tensor(6)
tensor(1) tensor(7)
tensor(2) tensor(0)
tensor(2) tensor(1)
tensor(2) tensor(2)
tensor(2) tensor(3)
tensor(2) tensor(4)
tensor(2) tensor(5)
tensor(2) tensor(6)
tensor(2) tensor(7)
tensor(3) tensor(0)
tensor(3) tensor(1)
tensor(3) tensor(2)
tensor(3) tensor(3)
tensor(3) tensor(4)
tensor(3) tensor(5)
tensor(3) tensor(6)
tensor(3) tensor(7)


In [None]:

# grid_n = tl.cdiv(N, BLOCK_SIZE_N)
# pid_m = pid // grid_n
# pid_n = pid % grid_n

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))

offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

In [29]:
import pandas as pd
import numpy as np

# Matrix dimensions
M = 4  # Number of rows in A and C
N = 8  # Number of columns in B and C
K = 16  # Number of columns in A and rows in B

# Block sizes
BLOCK_SIZE_M = 2
BLOCK_SIZE_N = 2
BLOCK_SIZE_K = 2

# Strides for row-major storage
stride_am = K  # Number of elements to skip to move to the next row in A
stride_ak = 1  # Number of elements to skip to move to the next column in A

stride_bk = N  # Number of elements to skip to move to the next row in B
stride_bn = 1  # Number of elements to skip to move to the next column in B

# Starting addresses (for simplicity, we start at 0)
a_ptr = 0
b_ptr = 0

# Calculate grid dimensions
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M  # ceil division
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N  # ceil division

n_programs = grid_m * grid_n  # Total number of program instances

# Program IDs
pids = np.arange(n_programs)

# Compute pid_m and pid_n for each program
pid_ms = pids // grid_n
pid_ns = pids % grid_n

# Data list to collect intermediate variables
data = []

for pid, pid_m, pid_n in zip(pids, pid_ms, pid_ns):
    # Compute offsets for A
    offs_am = (pid_m * BLOCK_SIZE_M + np.arange(0, BLOCK_SIZE_M)) % M
    offs_k = np.arange(0, BLOCK_SIZE_K)

    # Compute pointers for A
    a_offsets_row = offs_am[:, None] * stride_am  # Shape: (BLOCK_SIZE_M, 1)
    a_offsets_col = offs_k[None, :] * stride_ak   # Shape: (1, BLOCK_SIZE_K)
    a_ptrs = a_ptr + a_offsets_row + a_offsets_col  # Shape: (BLOCK_SIZE_M, BLOCK_SIZE_K)

    # Compute offsets for B
    offs_bn = (pid_n * BLOCK_SIZE_N + np.arange(0, BLOCK_SIZE_N)) % N

    # Compute pointers for B
    b_offsets_row = offs_k[:, None] * stride_bk    # Shape: (BLOCK_SIZE_K, 1)
    b_offsets_col = offs_bn[None, :] * stride_bn   # Shape: (1, BLOCK_SIZE_N)
    b_ptrs = b_ptr + b_offsets_row + b_offsets_col  # Shape: (BLOCK_SIZE_K, BLOCK_SIZE_N)

    # Collect data
    data.append({
        'pid': pid,
        'pid_m': pid_m,
        'pid_n': pid_n,
        'offs_am': offs_am.tolist(),
        'offs_bn': offs_bn.tolist(),
        'offs_k': offs_k.tolist(),
        'a_ptrs': a_ptrs.tolist(),
        'b_ptrs': b_ptrs.tolist(),
    })

# Create a pandas DataFrame from the data
df = pd.DataFrame(data)


In [30]:
df

Unnamed: 0,pid,pid_m,pid_n,offs_am,offs_bn,offs_k,a_ptrs,b_ptrs
0,0,0,0,"[0, 1]","[0, 1]","[0, 1]","[[0, 1], [16, 17]]","[[0, 1], [8, 9]]"
1,1,0,1,"[0, 1]","[2, 3]","[0, 1]","[[0, 1], [16, 17]]","[[2, 3], [10, 11]]"
2,2,0,2,"[0, 1]","[4, 5]","[0, 1]","[[0, 1], [16, 17]]","[[4, 5], [12, 13]]"
3,3,0,3,"[0, 1]","[6, 7]","[0, 1]","[[0, 1], [16, 17]]","[[6, 7], [14, 15]]"
4,4,1,0,"[2, 3]","[0, 1]","[0, 1]","[[32, 33], [48, 49]]","[[0, 1], [8, 9]]"
5,5,1,1,"[2, 3]","[2, 3]","[0, 1]","[[32, 33], [48, 49]]","[[2, 3], [10, 11]]"
6,6,1,2,"[2, 3]","[4, 5]","[0, 1]","[[32, 33], [48, 49]]","[[4, 5], [12, 13]]"
7,7,1,3,"[2, 3]","[6, 7]","[0, 1]","[[32, 33], [48, 49]]","[[6, 7], [14, 15]]"


In [31]:
2//2

1